diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 7d578d8b..94ec7f0a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -47,10 +47,10 @@ jobs: env: RELEASE_REPO: hw-native-sys/PTOAS - RELEASE_VER: 0.31 - RELEASE_TAG: v0.31 + RELEASE_VER: 0.36 + RELEASE_TAG: v0.36 CLI_DIR: /installers/ptoas-cli - PTOISA_COMMIT: 0af942568a4f2868673da0a35b0f5b64f27a20d5 + PTOISA_COMMIT: 4e27a104f948e883e0bef44670252381bff794c5 steps: - name: Install system packages diff --git a/docker/Dockerfile b/docker/Dockerfile index 4af438bb..bf02490b 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -15,13 +15,15 @@ RUN pip install --no-cache-dir \ pytest pybind11 nanobind setuptools wheel \ ipython jupyterlab matplotlib pandas -# certain operations need latest isa header, not CANN 8.5.0 default -# header on 2026/04/24 -ARG PTOISA_COMMIT=0af942568a4f2868673da0a35b0f5b64f27a20d5 +# For updated FA style +# https://gitcode.com/cann/pto-isa/commit/4e27a104f948e883e0bef44670252381bff794c5?ref=master +ARG PTOISA_COMMIT=4e27a104f948e883e0bef44670252381bff794c5 WORKDIR /sources -RUN git clone https://gitcode.com/cann/pto-isa.git \ +RUN git clone https://gitcode.com/cann/pto-isa \ && cd pto-isa && git checkout $PTOISA_COMMIT +ENV PTO_LIB_PATH=/sources/pto-isa + # cache above layers unrelated to ptoas version change # change this ununsed arg if need to force rebuild later lines @@ -29,8 +31,10 @@ ARG CACHE_BURST=1 # ARG ARCH=x86_64 ARG ARCH=aarch64 +# https://github.com/hw-native-sys/PTOAS/releases/tag/v0.36 +# include the split pipes https://github.com/hw-native-sys/PTOAS/pull/606 ARG RELEASE_REPO=hw-native-sys/PTOAS -ARG RELEASE_VER=0.31 +ARG RELEASE_VER=0.36 ARG RELEASE_TAG=v${RELEASE_VER} ARG WHEEL_NAME=ptoas-${RELEASE_VER}-cp311-none-manylinux_2_34_${ARCH}.whl ARG CLI_TAR_NAME=ptoas-bin-${ARCH}.tar.gz diff --git a/docker/README.md b/docker/README.md index 8deb808a..3a0f962f 100644 --- a/docker/README.md +++ b/docker/README.md @@ -3,10 +3,10 @@ 
Recommend using [Ascend Docker Runtime](https://gitcode.com/Ascend/mind-cluster/ Then, build and run docker image: ```bash -RELEASE_VER=0.29 +RELEASE_VER=0.36 sudo docker build \ --build-arg RELEASE_VER=$RELEASE_VER \ - . -t pto_dsl:$RELEASE_VER + . -t pto_dsl:fa_hack # for specific arch (x86_64 vs aarch64) sudo docker build \ @@ -30,7 +30,7 @@ sudo docker run --rm -it --ipc=host --privileged \ -v /usr/local/Ascend/driver:/usr/local/Ascend/driver:ro \ -v /etc/ascend_install.info:/etc/ascend_install.info:ro \ -v $HOME:/mounted_home -w /mounted_home \ - pto_dsl:$RELEASE_VER /bin/bash + pto_dsl:fa_hack /bin/bash ``` ## Appendix: NPU driver diff --git a/examples/aot/flash_attention/cpp_ref/naive_tpush/naive_tpush_dsl_plot.png b/examples/aot/flash_attention/cpp_ref/naive_tpush/naive_tpush_dsl_plot.png new file mode 100644 index 00000000..de04a5f9 Binary files /dev/null and b/examples/aot/flash_attention/cpp_ref/naive_tpush/naive_tpush_dsl_plot.png differ diff --git a/examples/aot/flash_attention/cpp_ref/simplified/fa_compile_and_run_s1_plot.png b/examples/aot/flash_attention/cpp_ref/simplified/fa_compile_and_run_s1_plot.png new file mode 100644 index 00000000..76f48baa Binary files /dev/null and b/examples/aot/flash_attention/cpp_ref/simplified/fa_compile_and_run_s1_plot.png differ diff --git a/examples/aot/flash_attention/cpp_ref/simplified/jit_util_flash.py b/examples/aot/flash_attention/cpp_ref/simplified/jit_util_flash.py index 1fa836cd..4d24b978 100644 --- a/examples/aot/flash_attention/cpp_ref/simplified/jit_util_flash.py +++ b/examples/aot/flash_attention/cpp_ref/simplified/jit_util_flash.py @@ -26,7 +26,7 @@ _CV_FIFO_SIZE = 8 # CV_FIFO_SIZE _CUBE_S0 = 128 # CUBE_S0 _SUPPORTED_TILE_S1 = (256, 512, 1024) -_DEFAULT_TILE_S1 = 256 +_DEFAULT_TILE_S1 = 512 _MAX_TILE_S1 = max(_SUPPORTED_TILE_S1) diff --git a/examples/aot/flash_attention/cpp_ref/simplified/run.py b/examples/aot/flash_attention/cpp_ref/simplified/run.py index 376fc457..c814f01b 100644 --- 
a/examples/aot/flash_attention/cpp_ref/simplified/run.py +++ b/examples/aot/flash_attention/cpp_ref/simplified/run.py @@ -95,7 +95,7 @@ def fused_attention(q, k, v, is_causal=False): return out.squeeze(0) -def test_flash(tile_s1: int = 256, head: int = 128): +def test_flash(tile_s1: int = 512, head: int = 128): s0 = 128 * 24 s1_values = [1024, 2048, 4096, 8192, 16384, 32768, 64 * 1024, 128 * 1024] bad_s1 = [s1 for s1 in s1_values if s1 % tile_s1 != 0] @@ -156,7 +156,7 @@ def test_flash(tile_s1: int = 256, head: int = 128): print(f"Tile S1 : {tile_s1}") print(f"FLOPs total : {flops_total}") print( - f"JIT flash kernel : {flash_ms:.3f} ms/iter " + f"PTO custom FlashAttention : {flash_ms:.3f} ms/iter " f"({tflops(flops_total, flash_ms):.3f} TFLOP/s)" ) print( @@ -164,20 +164,20 @@ def test_flash(tile_s1: int = 256, head: int = 128): f"({tflops(flops_total, npu_ms):.3f} TFLOP/s)" ) print( - f"torch reference : {ref_ms:.3f} ms/iter " + f"PyTorch Eager Reference : {ref_ms:.3f} ms/iter " f"({tflops(flops_total, ref_ms):.3f} TFLOP/s)" ) torch.testing.assert_close(o_out, o_ref, rtol=1e-3, atol=1e-3) print("vs torch reference: PASSED") torch.testing.assert_close(o_out, o_npu, rtol=1e-3, atol=1e-3) - print("vs npu_fused_attention: PASSED") + print("vs npu_fused_infer_attention_score: PASSED") print("") plot_path = Path(__file__).with_name("fa_compile_and_run_s1_plot.png") plt.figure(figsize=(8, 5)) - plt.plot(s1_values, flash_tflops_values, marker="o", label="flash") - plt.plot(s1_values, ref_tflops_values, marker="o", label="ref") - plt.plot(s1_values, npu_tflops_values, marker="o", label="torch_npu") + plt.plot(s1_values, flash_tflops_values, marker="o", label="PTO custom FlashAttention") + plt.plot(s1_values, ref_tflops_values, marker="o", label="PyTorch Eager Reference") + plt.plot(s1_values, npu_tflops_values, marker="o", label="torch_npu.npu_fused_infer_attention_score") plt.xscale("log", base=2) plt.xticks(s1_values, [str(v) for v in s1_values]) plt.xlabel("S1") @@ 
-195,7 +195,7 @@ def test_flash(tile_s1: int = 256, head: int = 128): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--tile-s1", type=int, choices=(256, 512, 1024), default=256) + parser.add_argument("--tile-s1", type=int, choices=(256, 512, 1024), default=512) parser.add_argument("--head", type=int, choices=(32, 64, 128), default=128) args = parser.parse_args() test_flash(tile_s1=args.tile_s1, head=args.head) diff --git a/examples/aot/flash_attention/cpp_ref/split_pipe/README.md b/examples/aot/flash_attention/cpp_ref/split_pipe/README.md new file mode 100644 index 00000000..2a4cb292 --- /dev/null +++ b/examples/aot/flash_attention/cpp_ref/split_pipe/README.md @@ -0,0 +1,11 @@ +Try different tile sizes (512 achieves the highest TFLOP/s for long sequences): + +```bash +export PTODSL_TEST_DEVICE_ID=7 + +python ./run.py --tile-s1 512 # default +python ./run.py --tile-s1 256 # slower for long sequences +python ./run.py --tile-s1 1024 # known issue: produces incorrect results +``` + +Reference outputs are in [./results](./results). diff --git a/examples/aot/flash_attention/cpp_ref/split_pipe/call_kernel_dispatch.cpp b/examples/aot/flash_attention/cpp_ref/split_pipe/call_kernel_dispatch.cpp new file mode 100644 index 00000000..60b1c0d2 --- /dev/null +++ b/examples/aot/flash_attention/cpp_ref/split_pipe/call_kernel_dispatch.cpp @@ -0,0 +1,43 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+*/ + +#include +#include +#include + +#include "fa_performance_kernel.h" +#include "generated_cases.h" +#include "runtime/rt.h" + +extern "C" void call_kernel(void *stream, int headSize, int s0, int s1, int tile_s1, bool is_causal, uint8_t *q, + uint8_t *k, uint8_t *v, uint8_t *o_out, float *qk_tile_fifo, uint16_t *p_tile_fifo, + float *exp_max_ififo, float *pv_tile_fifo, float *global_sum_out, float *exp_max_out, + float *o_parts_out) +{ + if (is_causal) { + return; + } + + uint64_t ffts_val = 0; + uint32_t ffts_len = 0; + rtGetC2cCtrlAddr(&ffts_val, &ffts_len); + auto *ffts = reinterpret_cast(static_cast(ffts_val)); + + uint8_t *cv_comm_buf = nullptr; + +#define LAUNCH_DISPATCH(S0_, HEAD_, S1_, CUBE_S0_, CUBE_S1_, TILE_S1_, QK_PRELOAD_, CAUSAL_MASK_) \ + if (headSize == (HEAD_) && (s0) == (S0_) && (s1) == (S1_) && tile_s1 == (TILE_S1_)) { \ + LaunchTFA<(S0_), (HEAD_), (S1_), (CUBE_S0_), (CUBE_S1_), (TILE_S1_), (QK_PRELOAD_), kFaCvFifoSize, false, \ + (CAUSAL_MASK_), kFaCvFifoConsSyncPeriod>( \ + ffts, reinterpret_cast(q), reinterpret_cast(k), \ + reinterpret_cast(v), reinterpret_cast(p_tile_fifo), exp_max_ififo, \ + global_sum_out, exp_max_out, reinterpret_cast(o_out), o_parts_out, qk_tile_fifo, pv_tile_fifo, \ + reinterpret_cast(stream), cv_comm_buf); \ + return; \ + } + + TFA_FOR_EACH_CASE(LAUNCH_DISPATCH); + +#undef LAUNCH_DISPATCH +} diff --git a/examples/aot/flash_attention/cpp_ref/split_pipe/generated_cases.h b/examples/aot/flash_attention/cpp_ref/split_pipe/generated_cases.h new file mode 100644 index 00000000..2973296e --- /dev/null +++ b/examples/aot/flash_attention/cpp_ref/split_pipe/generated_cases.h @@ -0,0 +1,71 @@ +#pragma once +// Auto-generated by scripts/generate_cases.py. Do not edit manually. 
+// clang-format off +#include + +#define TFA_FOR_EACH_CASE(MACRO) \ + MACRO(3072, 128, 1024, 128, 128, 256, 4, false) \ + MACRO(3072, 128, 1024, 128, 128, 512, 4, false) \ + MACRO(3072, 128, 1024, 128, 128, 1024, 4, false) \ + MACRO(3072, 128, 2048, 128, 128, 256, 4, false) \ + MACRO(3072, 128, 2048, 128, 128, 512, 4, false) \ + MACRO(3072, 128, 2048, 128, 128, 1024, 4, false) \ + MACRO(3072, 128, 4096, 128, 128, 256, 4, false) \ + MACRO(3072, 128, 4096, 128, 128, 512, 4, false) \ + MACRO(3072, 128, 4096, 128, 128, 1024, 4, false) \ + MACRO(3072, 128, 8192, 128, 128, 256, 4, false) \ + MACRO(3072, 128, 8192, 128, 128, 512, 4, false) \ + MACRO(3072, 128, 8192, 128, 128, 1024, 4, false) \ + MACRO(3072, 128, 16384, 128, 128, 256, 4, false) \ + MACRO(3072, 128, 16384, 128, 128, 512, 4, false) \ + MACRO(3072, 128, 16384, 128, 128, 1024, 4, false) \ + MACRO(3072, 128, 32768, 128, 128, 256, 4, false) \ + MACRO(3072, 128, 32768, 128, 128, 512, 4, false) \ + MACRO(3072, 128, 32768, 128, 128, 1024, 4, false) \ + MACRO(3072, 128, 65536, 128, 128, 256, 4, false) \ + MACRO(3072, 128, 65536, 128, 128, 512, 4, false) \ + MACRO(3072, 128, 65536, 128, 128, 1024, 4, false) \ + MACRO(3072, 128, 131072, 128, 128, 256, 4, false) \ + MACRO(3072, 128, 131072, 128, 128, 512, 4, false) \ + MACRO(3072, 128, 131072, 128, 128, 1024, 4, false) + +struct GeneratedTfaCase { + int s0; + int head_size; + int s1; + int cube_s0; + int cube_s1; + int tile_s1; + int qk_preload; + bool causal_mask; + const char *name; +}; + +static constexpr GeneratedTfaCase kGeneratedTfaCases[] = { + {3072, 128, 1024, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_1024"}, + {3072, 128, 1024, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_1024"}, + {3072, 128, 1024, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_1024"}, + {3072, 128, 2048, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_2048"}, + {3072, 128, 2048, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_2048"}, + {3072, 128, 
2048, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_2048"}, + {3072, 128, 4096, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_4096"}, + {3072, 128, 4096, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_4096"}, + {3072, 128, 4096, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_4096"}, + {3072, 128, 8192, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_8192"}, + {3072, 128, 8192, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_8192"}, + {3072, 128, 8192, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_8192"}, + {3072, 128, 16384, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_16384"}, + {3072, 128, 16384, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_16384"}, + {3072, 128, 16384, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_16384"}, + {3072, 128, 32768, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_32768"}, + {3072, 128, 32768, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_32768"}, + {3072, 128, 32768, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_32768"}, + {3072, 128, 65536, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_65536"}, + {3072, 128, 65536, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_65536"}, + {3072, 128, 65536, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_65536"}, + {3072, 128, 131072, 128, 128, 256, 4, false, "case_float_H_128_S0_3072_S1_131072"}, + {3072, 128, 131072, 128, 128, 512, 4, false, "case_float_H_128_S0_3072_S1_131072"}, + {3072, 128, 131072, 128, 128, 1024, 4, false, "case_float_H_128_S0_3072_S1_131072"} +}; +static constexpr std::size_t kGeneratedTfaCasesCount = sizeof(kGeneratedTfaCases) / sizeof(kGeneratedTfaCases[0]); +// clang-format on diff --git a/examples/aot/flash_attention/cpp_ref/split_pipe/jit_util_flash.py b/examples/aot/flash_attention/cpp_ref/split_pipe/jit_util_flash.py new file mode 100644 index 00000000..f9a846ad --- /dev/null +++ b/examples/aot/flash_attention/cpp_ref/split_pipe/jit_util_flash.py 
#!/usr/bin/python3
# coding=utf-8
# --------------------------------------------------------------------------------
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# --------------------------------------------------------------------------------

"""JIT compile bundled fa_performance_kernel.cpp + call_kernel_dispatch.cpp into flash_jit.so.

Requires environment variables (CANN / PTO headers — not vendored here):
  ASCEND_TOOLKIT_HOME — compiler and acl/runtime includes
  PTO_LIB_PATH — PTO headers (prefetch, sync, …)

Both variables are read at import time, so importing this module raises
``KeyError`` when either one is unset.

Kernel sources live under this directory: kernels/flash_atten/
"""

import ctypes
import os
import subprocess
from pathlib import Path
from typing import Callable, List, Optional

import torch

ASCEND_TOOLKIT_HOME = os.environ["ASCEND_TOOLKIT_HOME"]
PTO_LIB_PATH = os.environ["PTO_LIB_PATH"]

_SPLIT_PIPE_DIR = Path(__file__).resolve().parent
_KERNEL_DIR = _SPLIT_PIPE_DIR / "kernels" / "flash_atten"

# Host-side sizing constants used to allocate the kernel workspace.
# NOTE(review): presumably these must match the constants compiled into the
# kernel (e.g. kFaCvFifoSize in the dispatcher) — confirm against the headers.
_CV_FIFO_SIZE = 8
_CUBE_S0 = 128
_CUBE_S1 = 128
_SUPPORTED_TILE_S1 = (256, 512, 1024)
_DEFAULT_TILE_S1 = 512


def _pto_include_dir() -> Path:
    """Return the PTO include directory derived from ``PTO_LIB_PATH``.

    ``PTO_LIB_PATH`` may be the repo root (contains ``include/``) or the
    include dir itself; the nested ``include/`` directory wins when it exists.
    """
    root = Path(PTO_LIB_PATH).resolve()
    nested = root / "include"
    return nested if nested.is_dir() else root


def torch_to_ctypes(t: torch.Tensor) -> ctypes.c_void_p:
    """Wrap the tensor's data pointer in a ``c_void_p`` for a ctypes call."""
    return ctypes.c_void_p(t.data_ptr())


def _npu_arch_flag() -> str:
    """Return the ``--npu-arch`` value: ``$NPU_ARCH`` stripped, else ``dav-2201``."""
    return os.environ.get("NPU_ARCH", "dav-2201").strip()


def compile_flash(
    kernel_cpp: str,
    verbose: bool = False,
    timeout: int = 600,
    extra_sources: Optional[List[str]] = None,
    output_lib: Optional[str] = None,
) -> str:
    """Compile the kernel sources into a shared library with ``bisheng``.

    Args:
        kernel_cpp: Path of the main kernel translation unit.
        verbose: Print the compile command and the generated library path.
        timeout: Seconds to wait for the compiler before giving up.
        extra_sources: Additional source files compiled into the same library.
        output_lib: Output path; defaults to ``flash_jit.so`` next to this file.

    Returns:
        The path of the generated shared library.

    Raises:
        subprocess.CalledProcessError: The compiler exited with a non-zero status.
        subprocess.TimeoutExpired: The compiler ran longer than ``timeout``.
    """
    lib_path = output_lib or str(_SPLIT_PIPE_DIR / "flash_jit.so")

    includes = [
        f"-I{_pto_include_dir()}",
        f"-I{_KERNEL_DIR}",
        f"-I{_SPLIT_PIPE_DIR}",
        f"-I{ASCEND_TOOLKIT_HOME}/include",
        f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc/runtime",
        f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc",
        f"-I{ASCEND_TOOLKIT_HOME}/pkg_inc/profiling",
    ]

    flags = [
        "-fPIC",
        "-shared",
        "-xcce",
        f"--npu-arch={_npu_arch_flag()}",
        "-O2",
        "-std=c++17",
        "-Wno-ignored-attributes",
        *includes,
    ]

    sources = [kernel_cpp, *(extra_sources or [])]

    cmd = ["bisheng", *flags, *sources, "-o", lib_path]
    if verbose:
        print("compile command:\n", " ".join(cmd))

    # Argument list (no shell) so paths with spaces are passed through intact.
    subprocess.run(cmd, check=True, timeout=timeout)

    if verbose:
        print(f"generated {lib_path}")
    return lib_path


def load_flash_lib(lib_path: str, check_type: bool = True) -> Callable[..., torch.Tensor]:
    """Load the compiled library and return a ``flash(q, k, v, ...)`` callable.

    Args:
        lib_path: Path of the shared library exporting ``call_kernel``.
        check_type: Declare ``argtypes``/``restype`` on ``call_kernel`` so ctypes
            converts and checks every argument.

    Returns:
        A closure that launches the kernel and returns the output tensor. The
        closure owns a workspace cache keyed by ``(s0, head, tile_s1, device)``;
        the returned tensor is that cached buffer, so a later call with the same
        key overwrites the result of an earlier one.
    """
    lib_path = os.path.abspath(lib_path)
    lib = ctypes.CDLL(lib_path)

    if check_type:
        lib.call_kernel.argtypes = [
            ctypes.c_void_p,  # stream
            ctypes.c_int,  # head size
            ctypes.c_int,  # s0
            ctypes.c_int,  # s1
            ctypes.c_int,  # tile_s1
            ctypes.c_bool,  # is_causal
            ctypes.c_void_p,  # q
            ctypes.c_void_p,  # k
            ctypes.c_void_p,  # v
            ctypes.c_void_p,  # o_out
            ctypes.c_void_p,  # qk_tile_fifo
            ctypes.c_void_p,  # p_tile_fifo
            ctypes.c_void_p,  # exp_max_ififo
            ctypes.c_void_p,  # pv_tile_fifo
            ctypes.c_void_p,  # global_sum_out
            ctypes.c_void_p,  # exp_max_out
            ctypes.c_void_p,  # o_parts_out
        ]
        lib.call_kernel.restype = None

    _ws: dict = {}

    def _alloc_workspace(s0: int, head: int, tile_s1: int, device) -> None:
        """(Re)allocate the workspace tensors unless the cached key still matches."""
        shape = (s0, head, tile_s1, str(device))
        if _ws.get("_shape") == shape:
            return

        # Finish outstanding device work before the old buffers are dropped.
        torch.npu.synchronize()

        num_s0_blocks = s0 // _CUBE_S0
        slots = num_s0_blocks * _CV_FIFO_SIZE

        _ws.clear()
        _ws["_shape"] = shape
        _ws["o_out"] = torch.empty((s0, head), device=device, dtype=torch.float32)

        _ws["qk_tile_fifo"] = torch.empty(
            (slots, _CUBE_S0, tile_s1), device=device, dtype=torch.float32
        )
        _ws["p_tile_fifo"] = torch.empty(
            (slots, _CUBE_S0, tile_s1), device=device, dtype=torch.float16
        )
        _ws["exp_max_ififo"] = torch.empty(
            (slots, _CUBE_S0), device=device, dtype=torch.float32
        )
        _ws["pv_tile_fifo"] = torch.empty(
            (slots, _CUBE_S0, head), device=device, dtype=torch.float32
        )

        _ws["global_sum_out"] = torch.empty(
            (num_s0_blocks, s0), device=device, dtype=torch.float32
        )
        _ws["exp_max_out"] = torch.empty(
            (num_s0_blocks, s0), device=device, dtype=torch.float32
        )
        _ws["o_parts_out"] = torch.empty(
            (num_s0_blocks, s0, head), device=device, dtype=torch.float32
        )

    default_causal = False
    # Captured once, when the library is loaded: callers that switch NPU streams
    # afterwards must pass ``stream_ptr`` explicitly.
    default_stream_ptr = torch.npu.current_stream()._as_parameter_

    def flash(
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        stream_ptr=default_stream_ptr,
        is_causal=default_causal,
        tile_s1: int = _DEFAULT_TILE_S1,
    ) -> torch.Tensor:
        """Launch the flash-attention kernel and return the output buffer.

        ``q`` is read as ``(S0, head)`` and ``k.shape[0]`` as ``S1``.
        NOTE(review): q/k/v are presumably contiguous float16 NPU tensors —
        confirm against the kernel; no dtype or layout check is done here.

        Raises:
            NotImplementedError: ``is_causal`` is true.
            ValueError: ``tile_s1`` is unsupported or does not divide ``S1``.
        """
        if is_causal:
            # The bundled call_kernel returns before launching anything when
            # is_causal is set, which would hand back the uninitialized
            # torch.empty() output buffer as if it were a result.
            raise NotImplementedError(
                "is_causal=True is not supported: the bundled call_kernel dispatcher "
                "returns without launching a kernel for causal requests"
            )
        # Validate tile_s1 before it is used as a divisor (tile_s1=0 would
        # otherwise surface as ZeroDivisionError) and before any device sync.
        if tile_s1 not in _SUPPORTED_TILE_S1:
            raise ValueError(f"tile_s1 must be one of {_SUPPORTED_TILE_S1}, got {tile_s1}")
        s1 = k.shape[0]
        if s1 % tile_s1 != 0:
            raise ValueError(f"S1={s1} must be divisible by tile_s1={tile_s1}")

        s0, head = q.shape[0], q.shape[1]
        _alloc_workspace(s0, head, tile_s1, q.device)

        # NOTE(review): the bundled dispatcher appears to launch only for the
        # (s0, head, s1, tile_s1) combinations listed in generated_cases.h and to
        # return silently otherwise, leaving o_out uninitialized — confirm before
        # relying on shapes outside that table.
        lib.call_kernel(
            stream_ptr,
            head,
            s0,
            s1,
            tile_s1,
            is_causal,
            torch_to_ctypes(q),
            torch_to_ctypes(k),
            torch_to_ctypes(v),
            torch_to_ctypes(_ws["o_out"]),
            torch_to_ctypes(_ws["qk_tile_fifo"]),
            torch_to_ctypes(_ws["p_tile_fifo"]),
            torch_to_ctypes(_ws["exp_max_ififo"]),
            torch_to_ctypes(_ws["pv_tile_fifo"]),
            torch_to_ctypes(_ws["global_sum_out"]),
            torch_to_ctypes(_ws["exp_max_out"]),
            torch_to_ctypes(_ws["o_parts_out"]),
        )
        return _ws["o_out"]

    return flash


def jit_compile_flash(
    verbose: bool = False,
    clean_up: bool = True,
    kernel_cpp: Optional[str] = None,
):
    """Compile the bundled kernel plus dispatcher, load it, and return ``flash``.

    Args:
        verbose: Forwarded to :func:`compile_flash`.
        clean_up: Delete ``flash_jit.so`` once it has been loaded.
        kernel_cpp: Alternative kernel source; defaults to the bundled
            ``kernels/flash_atten/fa_performance_kernel.cpp``.

    Returns:
        The callable produced by :func:`load_flash_lib`.
    """
    kcpp = kernel_cpp or str(_KERNEL_DIR / "fa_performance_kernel.cpp")
    dispatch = str(_SPLIT_PIPE_DIR / "call_kernel_dispatch.cpp")
    lib_path = compile_flash(
        kcpp,
        verbose=verbose,
        extra_sources=[dispatch],
        output_lib=str(_SPLIT_PIPE_DIR / "flash_jit.so"),
    )
    func = load_flash_lib(lib_path)

    if clean_up:
        # Best-effort removal: the library has already been loaded by ctypes at
        # this point, and a failed delete must not break a working kernel.
        try:
            os.remove(lib_path)
        except OSError:
            pass

    return func

# (patch metadata carried over from the collapsed diff this module is embedded in)
# diff --git a/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.cpp b/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.cpp new
file mode 100644 index 00000000..d8de88e0 --- /dev/null +++ b/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.cpp @@ -0,0 +1,1015 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include +#include + +#include "fa_performance_kernel.h" +#include +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) +#include +#elif defined(__DAV_C310_CUBE__) || defined(__DAV_C310_VEC__) +#include +#endif +#include "pto_macro_matmul.hpp" +#include "pto_macro_fa_softmax.hpp" +#include "pto_macro_fa_gu.hpp" + +#define UF_ENABLE 1 + +using namespace std; +using namespace pto; + +#ifndef FFTS_BUFFER_FLAG_ENUM +#define FFTS_BUFFER_FLAG_ENUM +// Buffer flag values for FFTS pipeline coordination +enum FftsBufferFlag : uint32_t +{ + BUF0_QK_READY = 0, // Buffer 0: QK data ready + BUF0_SM_CONSUMED = 1, // Buffer 0: Softmax consumed + BUF1_SM_READY = 2, // Buffer 1: Softmax output ready + BUF1_SV_CONSUMED = 3, // Buffer 1: SV consumed + UPDATE_READY = 4, // Update stage ready + UPDATE_CONSUMED = 5, // Update stage consumed + CV_BLOCK_END = 7, // CV comm slot block end (CV_COMM_CTRL reserved in TSyncCVID) +}; +#endif + +enum CoreEvtID : uint32_t +{ + QK_EVENT_ID0, + QK_EVENT_ID1, + PV_EVENT_ID0, + PV_EVENT_ID1, +}; + +// ----------------------------------------------------------------------------- +// Performance tuning knobs (high-level) +// +// The kernel is a 
cross-core pipeline (Cube + Vec) with explicit FIFOs: +// QK (Cube): compute_qk -> qk_tile_fifo (fp32) +// P (Vec): compute_p -> p_tile_fifo (fp16 x_exp) + l1_exp_max_ififo +// PV (Cube): compute_pv -> pv_tile_fifo (fp32) +// GU (Vec): compute_gu -> o_out (fp32) with running rescale/update +// +// Key knobs that impact throughput (see runTFA<> below): +// - CUBE_S0 / CUBE_S1: tile sizes for QK/PV cube matmuls (compute intensity vs. buffer pressure) +// - qkPreloadNum: pipeline warmup depth (more overlap vs. more L1 FIFO footprint) +// - *_TNBuffers: ping/pong depth for Mat tiles (overlap) and Vec tiles (latency hiding) +// - QKV_CV_FIFO / PV_CV_FIFO: FIFO depth between stages (avoid backpressure) +// ----------------------------------------------------------------------------- + +// Inline macro used for small, performance-sensitive functions +#ifndef PTO_INLINE +#define PTO_INLINE __attribute__((always_inline)) inline +#endif + +// Detect build-time macros and expose as constexpr flags for clearer conditionals +#ifdef __DAV_CUBE__ +constexpr bool DAV_CUBE = true; +#else +constexpr bool DAV_CUBE = false; +#endif + +#ifdef __DAV_VEC__ +constexpr bool DAV_VEC = true; +#else +constexpr bool DAV_VEC = false; +#endif + +constexpr std::size_t MAX_TILE_L1_BYTES = 512U * 1024U; +constexpr std::size_t MAX_VEC_UB_BYTES = 192U * 1024U; + +template +constexpr AICORE std::size_t tile_storage_bytes() +{ + using ElementType = typename TileType::DType; + return static_cast(TileType::Rows * TileType::Cols) * sizeof(ElementType); +} + +template +constexpr AICORE std::size_t tile_buffer_total_bytes() +{ + return tile_storage_bytes() * NumBuffers; +} + +template +AICORE inline uint32_t assign_tile_buffers(TileType (&tiles)[NumBuffers], uint32_t base_offset) +{ + if constexpr (NumBuffers == 0) { + return base_offset; + } + + constexpr std::size_t total_storage_bytes = tile_buffer_total_bytes(); + static_assert(total_storage_bytes <= MAX_TILE_L1_BYTES, "Tile buffer L1 allocation exceeds 
512KB"); + + for (std::size_t idx = 0; idx < NumBuffers; ++idx) { + const uint32_t tile_offset = base_offset + static_cast(idx * tile_storage_bytes()); + TASSIGN(tiles[idx], tile_offset); + } + + return base_offset + static_cast(total_storage_bytes); +} + +template +AICORE inline uint32_t assign_tile_buffers_union(TileA (&tilesA)[NumA], TileB (&tilesB)[NumB], uint32_t base_offset) +{ + static_assert(NumA == NumB, "Union assignment expects matching buffer counts"); + if constexpr (NumA == 0) { + return base_offset; + } + + constexpr std::size_t stride_bytes = (tile_storage_bytes() > tile_storage_bytes()) ? + tile_storage_bytes() : + tile_storage_bytes(); + constexpr std::size_t total_storage_bytes = stride_bytes * NumA; + static_assert(total_storage_bytes <= MAX_VEC_UB_BYTES, "Union tile UB allocation exceeds 192KB"); + + for (std::size_t idx = 0; idx < NumA; ++idx) { + const uint32_t tile_offset = base_offset + static_cast(idx * stride_bytes); + TASSIGN(tilesA[idx], tile_offset); + TASSIGN(tilesB[idx], tile_offset); + } + + return base_offset + static_cast(total_storage_bytes); +} + +template +AICORE inline void allocate_cube_tile_buffers(TileQType (&qTiles)[NumQ], TileKType (&kTiles)[NumK], + TilePType (&pTiles)[NumP], TileVType (&vTiles)[NumV]) +{ + constexpr std::size_t total_bytes = + tile_buffer_total_bytes() + tile_buffer_total_bytes() + + tile_buffer_total_bytes() + tile_buffer_total_bytes(); + static_assert(total_bytes <= MAX_TILE_L1_BYTES, "Total cube L1 allocation exceeds 512KB"); + + uint32_t l1_offset = 0; + l1_offset = assign_tile_buffers(qTiles, l1_offset); + l1_offset = assign_tile_buffers(kTiles, l1_offset); + l1_offset = assign_tile_buffers(pTiles, l1_offset); + l1_offset = assign_tile_buffers(vTiles, l1_offset); + (void)l1_offset; +} + +template +AICORE inline void allocate_vec_tile_buffers(TileDataF_T (&srcTiles)[SrcBuffers], ReduceTileF_T &m1_local_max, + TileDataF_T &input_reduce_tmp, ReduceTileF_T &l1_local_sum, + ReduceTileF_T &m2_global_max, 
ReduceTileF_T &l2_global_sum, + ReduceTileF_T (&l1_exp_max)[ExpMaxBuffers], + TileDataH_T (&x_expT)[XexpBuffers], TileOutT (&pvTile)[pvVecBuffers], + TileOutT &runningOTile, TileDataF_T &triu) +{ + constexpr std::size_t float_tile_bytes = tile_storage_bytes(); + constexpr std::size_t reduce_tile_bytes = tile_storage_bytes(); + constexpr std::size_t xexp_bytes = tile_buffer_total_bytes(); + constexpr std::size_t out_tile_bytes = tile_storage_bytes(); + constexpr std::size_t union_stride = (tile_storage_bytes() > tile_storage_bytes()) ? + tile_storage_bytes() : + tile_storage_bytes(); + static_assert(SrcBuffers == pvVecBuffers, "src/pv ping-pong buffer counts must match for union allocation"); + constexpr std::size_t union_bytes = union_stride * SrcBuffers; + constexpr std::size_t total_bytes = union_bytes + xexp_bytes + (reduce_tile_bytes * (3U + ExpMaxBuffers)) + + (float_tile_bytes / 8 * 1U) + (float_tile_bytes * 1U) + out_tile_bytes; + static_assert(total_bytes <= MAX_VEC_UB_BYTES, "Vec tile UB allocation exceeds 192KB"); + + uint32_t offset = 0; + TASSIGN(runningOTile, offset); + offset += out_tile_bytes; + offset = assign_tile_buffers_union(srcTiles, pvTile, offset); + + TASSIGN(m1_local_max, offset); + offset += static_cast(reduce_tile_bytes); + + TASSIGN(m2_global_max, offset); + offset += static_cast(reduce_tile_bytes); + + uint32_t tmp_float_offset = offset; + TASSIGN(input_reduce_tmp, tmp_float_offset); + offset += static_cast(float_tile_bytes) / 8; + + TASSIGN(triu, offset); + offset += static_cast(float_tile_bytes); + + TASSIGN(l1_local_sum, offset); + offset += static_cast(reduce_tile_bytes); + + TASSIGN(l2_global_sum, offset); + offset += static_cast(reduce_tile_bytes); + + offset = assign_tile_buffers(l1_exp_max, offset); + + uint32_t tail_offset = assign_tile_buffers(x_expT, offset); + + (void)tail_offset; +} + +// Helper to assign an accumulator tile to one of two ping-pong UB addresses (0x0 / 0x10000). 
+// Keeps a per-type static running index that toggles on every call. Caller may pass +// `initial_id` (0 or 1) to set the starting buffer index on the first call for that tile type. +template +AICORE inline int assign_running_acc_tile(AccTileT &accTile, int initial_id = -1) +{ + static int running_tile_buffer_idx = 0; // per-instantiation running buffer index: 0 -> base0, 1 -> base1 + if (initial_id == 0 || initial_id == 1) { + running_tile_buffer_idx = initial_id; + } + const int id = running_tile_buffer_idx; + const uint32_t base_addr = (id == 0) ? 0x0u : 0x10000u; + TASSIGN(accTile, base_addr); + running_tile_buffer_idx ^= 1; // toggle for next call + return id; +} + +template +AICORE inline void compute_qk(QKPipe &qkPipe, int tile_id, int sub_tile_id, __gm__ half *q, __gm__ half *k, + __gm__ float *qk_tile_fifo, TileMatQData &qMatTile, TileMatKData &kMatTile, + TileQKData &qkAccTile, QKSlotGlobal &qkSlotGlobal, uint64_t qkMatTileEventId, + int accTileEvtID, int blk_idx) +{ + if constexpr (DAV_CUBE) { + constexpr uint32_t Cube_S0 = CUBE_S0; + constexpr uint32_t Cube_S1 = CUBE_S1; + constexpr uint32_t Tile_S1 = TILE_S1; + constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1; + constexpr uint32_t Cube_HEAD = HEAD_SIZE; + + constexpr int QKP_CV_FIFO = QKPipe::RingFiFo::SLOT_NUM; + static_assert(QKP_CV_FIFO >= 1, "QKP_CV_FIFO must be >= 1"); + static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1"); + + const int s0_index = blk_idx * CUBE_S0; + const int s1_index = tile_id * static_cast(Tile_S1) + sub_tile_id * static_cast(Cube_S1); + if (sub_tile_id == 0) { + TALLOC(qkPipe, qkSlotGlobal); + } + if constexpr (CAUSAL_MASK) { + if (s1_index > s0_index) { + if (sub_tile_id == static_cast(kTileFactor) - 1) { + TPUSH(qkPipe, qkSlotGlobal); + } + return; + } + } + using GlobalDataQ = + GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + using GlobalDataK = GlobalTensor, + pto::Stride<1, 1, 1, 1, HEAD_SIZE>, Layout::DN>; // BNSD - (N, K) layout + + 
GlobalDataQ qGlobal(q); + GlobalDataK kGlobal(k + s1_index * HEAD_SIZE); + + wait_flag(PIPE_MTE1, PIPE_MTE2, qkMatTileEventId); + + if (tile_id == 0 && sub_tile_id == 0) { + TLOAD(qMatTile, qGlobal); + } + + TLOAD(kMatTile, kGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + +#if UF_ENABLE + pto_macro_matmul(qMatTile, kMatTile, qkAccTile, AccMode::InitFinalSum); +#else + wait_flag(PIPE_FIX, PIPE_M, accTileEvtID); + pto_macro_matmul(qMatTile, kMatTile, qkAccTile, AccMode::Init); +#endif + + set_flag(PIPE_MTE1, PIPE_MTE2, qkMatTileEventId); +#if !UF_ENABLE + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); +#endif + + using QKStoreGlobal = + GlobalTensor, pto::Stride<1, 1, 1, Cube_S1, 1>>; + const uint32_t buf_idx = static_cast(tile_id % QKP_CV_FIFO); + const size_t base_elems = + static_cast(buf_idx) * static_cast(kTileFactor) * static_cast(Cube_S0) * + static_cast(Cube_S1) + + static_cast(sub_tile_id) * static_cast(Cube_S0) * static_cast(Cube_S1); + QKStoreGlobal qkStoreGlobal(qk_tile_fifo + base_elems); +#if UF_ENABLE + TSTORE(qkStoreGlobal, qkAccTile); +#else + TSTORE(qkStoreGlobal, qkAccTile); +#endif + + if (sub_tile_id == static_cast(kTileFactor) - 1) { + TPUSH(qkPipe, qkSlotGlobal); + } + +#if !UF_ENABLE + set_flag(PIPE_FIX, PIPE_M, accTileEvtID); +#endif + } +} + +template +AICORE inline void compute_pv(PPipe &pPipe, PVPipe &pvPipe, int tile_id, int sub_tile_id, __gm__ half *v, + __gm__ half *p_tile_fifo, TileMatPData &pMatTile, TileMatVData &vMatTile, + TilePVData &pvAccTile, PSlotGlobal &pSlotGlobal, PVSlotGlobal &pvSlotGlobal, + uint64_t svMatTileEventId, int accTileEvtID, int blk_idx) +{ + constexpr uint32_t Cube_S0 = CUBE_S0; + constexpr uint32_t Cube_S1 = CUBE_S1; + constexpr uint32_t Tile_S1 = TILE_S1; + constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1; + constexpr uint32_t Cube_HEAD = HEAD_SIZE; + constexpr uint32_t TileElems = Cube_S0 * Tile_S1; + constexpr int 
QKP_CV_FIFO = PVPipe::RingFiFo::SLOT_NUM; + static_assert(QKP_CV_FIFO >= 1, "PV_CV_FIFO must be >= 1"); + static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1"); + + const int s0_index = blk_idx * Cube_S0; + const int s1_index = tile_id * static_cast(Tile_S1) + sub_tile_id * static_cast(Cube_S1); + const bool is_last_subtile = (sub_tile_id + 1 == static_cast(kTileFactor)); + const bool next_will_be_skipped = (s1_index + static_cast(Cube_S1)) > s0_index && CAUSAL_MASK; + + if constexpr (DAV_CUBE) { + if (sub_tile_id == 0) { + TPOP(pPipe, pSlotGlobal); + } + if constexpr (CAUSAL_MASK) { + if (s1_index > s0_index) { + if (is_last_subtile) { + TFREE(pPipe, pSlotGlobal); + } + return; + } + } + + using GlobalVT = + GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + + wait_flag(PIPE_MTE1, PIPE_MTE2, svMatTileEventId); + + GlobalVT vLoad((__gm__ half *)(v + s1_index * HEAD_SIZE)); + TLOAD(vMatTile, vLoad); + + using PLoadGlobal = GlobalTensor, pto::Stride<1, 1, 1, Cube_S1, 1>>; + const uint32_t buf_idx = static_cast(tile_id % PPipe::RingFiFo::SLOT_NUM); + const size_t base_elems = + static_cast(buf_idx) * static_cast(Cube_S0) * static_cast(Tile_S1) + + static_cast(sub_tile_id) * static_cast(Cube_S0) * static_cast(Cube_S1); + PLoadGlobal pLoadGlobal(p_tile_fifo + base_elems); + TLOAD(pMatTile, pLoadGlobal); + if (is_last_subtile) { + TFREE(pPipe, pSlotGlobal); + } + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + +#if !UF_ENABLE + if (sub_tile_id == 0) { + wait_flag(PIPE_FIX, PIPE_M, accTileEvtID); + } +#endif + +#if UF_ENABLE + const AccMode accMode = + (sub_tile_id == 0) ? + (is_last_subtile || next_will_be_skipped ? AccMode::InitFinalSum : AccMode::InitPartialSum) : + (is_last_subtile || next_will_be_skipped ? AccMode::AccFinalSum : AccMode::AccPartialSum); + pto_macro_matmul(pMatTile, vMatTile, pvAccTile, accMode); +#else + const AccMode accMode = (sub_tile_id == 0) ? 
AccMode::Init : AccMode::Acc; + pto_macro_matmul(pMatTile, vMatTile, pvAccTile, accMode); +#endif + + set_flag(PIPE_MTE1, PIPE_MTE2, svMatTileEventId); + + if (sub_tile_id == static_cast(kTileFactor) - 1 || next_will_be_skipped) { +#if !UF_ENABLE + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); +#endif + + using PVStoreGlobal = + GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + TALLOC(pvPipe, pvSlotGlobal); + PVStoreGlobal pvStoreGlobal(pvSlotGlobal.data()); +#if UF_ENABLE + TSTORE(pvStoreGlobal, pvAccTile); +#else + TSTORE(pvStoreGlobal, pvAccTile); +#endif + TPUSH(pvPipe, pvSlotGlobal); + +#if !UF_ENABLE + set_flag(PIPE_FIX, PIPE_M, accTileEvtID); +#endif + } // end if (last sub-tile, or next sub-tile will be skipped): store and push the PV result + } // end if DAV_CUBE +} + +template +AICORE inline void compute_p(QKPipe &qkPipe, PPipe &pPipe, int tile_id, int row_slice, __gm__ float *exp_max_ififo, + __gm__ float *qk_tile_fifo, __gm__ half *p_tile_fifo, __gm__ float *global_sum_out, + __gm__ float *exp_max_out, TileDataF_T &qkVecTile, TileDataH_T &x_expT, + TileDataF_T &input_reduce_tmp, ReduceTileF_T &m1_local_max, ReduceTileF_T &l1_local_sum, + ReduceTileF_T &m2_global_max, ReduceTileF_T &l2_global_sum, + ReduceTileF_T &l1_exp_max_ififo, TileDataF_T triu, QKVecSlotGlobal &qkVecSlotGlobal, + PVecSlotGlobal &pVecSlotGlobal, uint64_t pTileEventId, int blk_idx) +{ + constexpr uint32_t Cube_S0 = CUBE_S0; + constexpr uint32_t Cube_S1 = CUBE_S1; + constexpr uint32_t Tile_S1 = TILE_S1; + constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1; + constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor; + constexpr int QKP_CV_FIFO = QKPipe::RingFiFo::SLOT_NUM; + static_assert(QKP_CV_FIFO >= 1, "QKP_CV_FIFO must be >= 1"); + static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1"); + static_assert(Cube_S0 % (VEC_CORES * kTileFactor) == 0, "Vec rows must divide evenly across tile slices"); + const bool initFlag = (tile_id == 0); + if constexpr (DAV_VEC) { + const size_t subblock_base_rows
= + static_cast(Cube_S0 / VEC_CORES) * static_cast(get_subblockid()); + const size_t local_row_offset = static_cast(row_slice * Vec_S0); + const size_t row_offset = subblock_base_rows + local_row_offset; + const int s0_index = blk_idx * Cube_S0 + row_offset; + const int s1_index = tile_id * static_cast(Tile_S1); + wait_flag(PIPE_V, PIPE_MTE2, pTileEventId); + + if (row_slice == 0) { + TPOP(qkPipe, qkVecSlotGlobal); + } + + const uint32_t qk_buf_idx = static_cast(tile_id % QKP_CV_FIFO); + const size_t qk_base_elems = static_cast(qk_buf_idx) * static_cast(kTileFactor) * + static_cast(Cube_S0) * static_cast(Cube_S1); + __gm__ float *qk_ptr = qk_tile_fifo + qk_base_elems + row_offset * static_cast(Cube_S1); + using QKLoadGlobal = + GlobalTensor, pto::Stride<1, 1, 1, Cube_S1, 1>>; + using TileDataFSub = Tile; + for (int sub_col = 0; sub_col < static_cast(kTileFactor); ++sub_col) { + QKLoadGlobal qkLoadGlobal(qk_ptr + static_cast(sub_col) * static_cast(Cube_S0) * + static_cast(Cube_S1)); + TileDataFSub qkVecSub; + TASSIGN(qkVecSub, (uint64_t)qkVecTile.data() + + static_cast(sub_col) * static_cast(Cube_S1) * sizeof(float)); + TLOAD(qkVecSub, qkLoadGlobal); + } + if (row_slice == static_cast(kTileFactor) - 1) { + TFREE(qkPipe, qkVecSlotGlobal); + } + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Extract per-slice views into the per-core reduce tiles so each slice writes into its row range + using ReduceSliceTile = Tile; + // reduce tiles live per vector core; offset only by row_slice within the core (no subblock stride) + const size_t reduce_slice_rows = static_cast(row_slice * Vec_S0); + const uint64_t reduce_row_byte_offset = reduce_slice_rows * sizeof(float); + + ReduceSliceTile m1_local_max_slice; + ReduceSliceTile l1_local_sum_slice; + ReduceSliceTile m2_global_max_slice; + ReduceSliceTile l2_global_sum_slice; + ReduceSliceTile l1_exp_max_slice; + + TASSIGN(m1_local_max_slice, (uint64_t)m1_local_max.data() + 
reduce_row_byte_offset); + TASSIGN(l1_local_sum_slice, (uint64_t)l1_local_sum.data() + reduce_row_byte_offset); + TASSIGN(m2_global_max_slice, (uint64_t)m2_global_max.data() + reduce_row_byte_offset); + TASSIGN(l2_global_sum_slice, (uint64_t)l2_global_sum.data() + reduce_row_byte_offset); + TASSIGN(l1_exp_max_slice, (uint64_t)l1_exp_max_ififo.data() + reduce_row_byte_offset); + + // Extract current slice state from full-length reduce tiles + wait_flag(PIPE_MTE3, PIPE_V, pTileEventId); + if (initFlag) { + pto_macro_fa_softmax( + x_expT, qkVecTile, m1_local_max_slice, l1_local_sum_slice, m2_global_max_slice, l2_global_sum_slice, + l1_exp_max_slice, input_reduce_tmp, qkVecTile, triu, s0_index, s1_index); + } else { + pto_macro_fa_softmax( + x_expT, qkVecTile, m1_local_max_slice, l1_local_sum_slice, m2_global_max_slice, l2_global_sum_slice, + l1_exp_max_slice, input_reduce_tmp, qkVecTile, triu, s0_index, s1_index); + } + + set_flag(PIPE_V, PIPE_MTE2, pTileEventId); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + if (row_slice == 0) { + TALLOC(pPipe, pVecSlotGlobal); + } + using PStoreGlobal = GlobalTensor, pto::Stride<1, 1, 1, Cube_S1, 1>>; + using TileDataHSub = Tile; + __gm__ half *p_ptr = p_tile_fifo + qk_base_elems + row_offset * static_cast(Cube_S1); + for (int sub_col = 0; sub_col < static_cast(kTileFactor); ++sub_col) { + PStoreGlobal pStoreGlobal(p_ptr + static_cast(sub_col) * static_cast(Cube_S0) * + static_cast(Cube_S1)); + TileDataHSub xExpSub; + TASSIGN(xExpSub, (uint64_t)x_expT.data() + + static_cast(sub_col) * static_cast(Cube_S1) * sizeof(half)); + TSTORE(pStoreGlobal, xExpSub); + } + if (row_slice == static_cast(kTileFactor) - 1) { + TPUSH(pPipe, pVecSlotGlobal); + } + + set_flag(PIPE_MTE3, PIPE_V, pTileEventId); + if constexpr (INTERMEDIATE_CHECK) { + // On the final row_slice, emit the exp_max for this subblock only (Cube_S0 / VEC_CORES rows) + if (row_slice == static_cast(kTileFactor) - 1) { + constexpr 
uint32_t SubblockRows = Cube_S0 / VEC_CORES; + using GlobalPMaxFloatSub = + GlobalTensor, pto::Stride<1, 1, 1, Cube_S0, 1>>; + using ExpMaxSub = Tile; + const size_t base_elems_pmax = + static_cast(tile_id % QKP_CV_FIFO) * static_cast(Cube_S0) + subblock_base_rows; + __gm__ float *p_ptr_fp32 = exp_max_ififo + base_elems_pmax; + GlobalPMaxFloatSub pMaxGlobal(p_ptr_fp32); + ExpMaxSub l1_exp_max_rowmajor; + TRESHAPE(l1_exp_max_rowmajor, l1_exp_max_ififo); + TSTORE(pMaxGlobal, l1_exp_max_rowmajor); + } + } + } +} + +template +AICORE inline void compute_gu(PVPipe &pvPipe, int tile_id, int num_tiles, __gm__ float *o_out, + __gm__ float *o_parts_out, TileOutT &runningOTile, TileOutT &pvVecTile, + ReduceTileF_T &l1_exp_max_ififo, ReduceTileF_T &l2_global_sum, uint64_t guEventId) +{ + constexpr uint32_t Cube_S0 = CUBE_S0; + constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES; + + if constexpr (DAV_VEC) { + wait_flag(PIPE_V, PIPE_MTE2, guEventId); + const size_t subblock_base_rows = + static_cast(Cube_S0 / VEC_CORES) * static_cast(get_subblockid()); + + using PVVecGlobal = + GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + PVVecGlobal pvGlobal; + TPOP(pvPipe, pvGlobal); + + if (tile_id == 0) { + TLOAD(runningOTile, pvGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + if constexpr (CAUSAL_MASK) { + if (tile_id == num_tiles - 1) + pto_macro_fa_gu_single_and_last_tile(runningOTile, l2_global_sum); + } + } else { + TLOAD(pvVecTile, pvGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + if (tile_id < num_tiles - 1) { + pto_macro_fa_gu(runningOTile, pvVecTile, l1_exp_max_ififo); + } else { + pto_macro_fa_gu_last(runningOTile, pvVecTile, l1_exp_max_ififo, l2_global_sum); + } + } + TFREE(pvPipe, pvGlobal); + + set_flag(PIPE_V, PIPE_MTE2, guEventId); + + if (tile_id == num_tiles - 1) { + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + using GlobalOutT = + 
GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + GlobalOutT outGlobal((__gm__ float *)(o_out + subblock_base_rows * HEAD_SIZE)); + TSTORE(outGlobal, runningOTile); + } + } +} + +template +__global__ AICORE void runTFA(__gm__ uint64_t *ffts_addr, __gm__ half *q, __gm__ half *k, __gm__ half *v, + __gm__ half *p_tile_fifo, __gm__ float *exp_max_ififo, __gm__ float *global_sum_out, + __gm__ float *exp_max_out, __gm__ float *o_out, __gm__ float *o_parts_out, + __gm__ float *qk_tile_fifo, __gm__ float *pv_tile_fifo, __gm__ uint8_t *cv_comm_buf, + __gm__ uint8_t *profile_buf) +{ + uint64_t tStart = get_sys_cnt(); + + set_ffts_base_addr((uint64_t)ffts_addr); + if constexpr (DAV_CUBE) { + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + } + + // Rename dimensions for clarity: S0 (rows total), Cube_S0 (per-block rows), S1 (cols), HEAD_SIZE (inner) + constexpr uint32_t Cube_S0 = CUBE_S0; + constexpr uint32_t block_rows = S0 / CUBE_S0; + constexpr uint32_t Cube_S1 = CUBE_S1; // per-tile S1 chunk + constexpr uint32_t Tile_S1 = TILE_S1; // logical tile along S1 + static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1"); + constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1; // sub-tiles per TILE_S1 + constexpr uint32_t Cube_HEAD = HEAD_SIZE; + constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor; + constexpr uint32_t VecGuRows = Cube_S0 / VEC_CORES; + static_assert(Cube_S0 % (VEC_CORES * kTileFactor) == 0, "Vec rows must divide evenly across tile slices"); + + // -------------------------- + // Tuning knobs (pipeline) + // + // qkPreloadNum controls how many (QK -> P) tiles we warm up before entering the steady-state loop. + // - Larger preload improves overlap (Cube/VEC concurrency) for long S1. + // - Larger preload increases FIFO footprint (qkGlobalTensorNBuffers / pvGlobalTensorNBuffers / + // guGlobalTensorNBuffers). 
+ constexpr uint32_t qkPreloadNum = QK_PRELOAD; + + // Buffer counts for optional double-buffering (default 1) + // - srcVecTNBuffers/xexpVecTNBuffers: Vec ping-pong for QK load and x_exp output + // - *MatTNBuffers: L1 ping-pong for Cube stage (K/P/V) + // Keep these small (1-2) unless you have measured stall bubbles that require deeper buffering. + constexpr uint32_t srcVecTNBuffers = 2; + constexpr uint32_t xexpVecTNBuffers = 2; + constexpr uint32_t outOTileNBuffers = 2; + constexpr uint32_t qMatTNBuffers = 1; + constexpr uint32_t kMatTNBuffers = 2; + constexpr uint32_t pMatTNBuffers = 2; + constexpr uint32_t vMatTNBuffers = 2; + constexpr uint32_t qkp_tile_fifo_size = CV_FIFO_SIZE; + constexpr uint32_t pv_tile_fifo_size = CV_FIFO_SIZE; + static_assert(qkPreloadNum >= 1, "qkPreloadNum must be >= 1"); + static_assert(CV_FIFO_CONS_SYNC_PERIOD >= 1, "CV_FIFO_CONS_SYNC_PERIOD must be >= 1"); + static_assert((qkPreloadNum > 1) || (kTileFactor == 1), "qkPreloadNum must be > 1 unless kTileFactor == 1"); + + // Define tile types for first QK matmul + using TileMatQData = + Tile; + using TileMatKData = + Tile; + // Accumulator rows must match Cube_S0 (per-block rows), not logical S0 + using TileQKData = TileAcc; + + TileMatQData qMatTile[qMatTNBuffers]; + TileMatKData kMatTile[kMatTNBuffers]; + TileQKData qkAccTile; + + // Define tile types for second PV matmul + using TileMatPData = + Tile; + using TileMatVData = + Tile; + using TilePVData = TileAcc; + + TileMatPData pMatTile[pMatTNBuffers]; + TileMatVData vMatTile[vMatTNBuffers]; + TilePVData pvAccTile; + + allocate_cube_tile_buffers(qMatTile, kMatTile, pMatTile, vMatTile); + + // Assign accumulator tiles using ping-pong helper. qk starts at 0, pv starts at 1. 
+ assign_running_acc_tile(qkAccTile, 0); + assign_running_acc_tile(pvAccTile, 1); + + // Define tile types for FA softmax P computation + // UB offsets for softmax tiles + // Define per-tile vector tiles sized to Cube_S1 + using TileDataF_T = Tile; + using TileDataH_T = Tile; + constexpr uint32_t SubblockRows = Cube_S0 / VEC_CORES; + // Reduce tiles cover one vector core's rows (Cube_S0 / VEC_CORES); slices are extracted per row_slice + using ReduceTileF_T = Tile; + + TileDataF_T qkVecTile[srcVecTNBuffers]; + ReduceTileF_T m1_local_max; + TileDataF_T input_reduce_tmp; + TileDataF_T triu; + ReduceTileF_T l1_local_sum; + ReduceTileF_T m2_global_max; + ReduceTileF_T l2_global_sum; + ReduceTileF_T l1_exp_max_ififo[qkp_tile_fifo_size]; + TileDataH_T x_expT[xexpVecTNBuffers]; + + using TileOutGuT = Tile; + TileOutGuT pvVecTile[outOTileNBuffers]; + TileOutGuT runningOTile; + allocate_vec_tile_buffers(qkVecTile, m1_local_max, input_reduce_tmp, l1_local_sum, m2_global_max, + l2_global_sum, l1_exp_max_ififo, x_expT, pvVecTile, runningOTile, triu); + + // block offset for logical S0 +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) // A5 defined macro, don't need to reassign + const int block_idx = get_block_idx(); +#endif + const int block_offset_rows = block_idx * static_cast(Cube_S0); + + constexpr bool use_cv_comm = (!INTERMEDIATE_CHECK) && (block_rows >= static_cast(pto::kCvMaxCores)); + int comm_slot = block_idx; + + if constexpr (use_cv_comm) { + comm_slot = pto::TSYNC_CVID(block_idx, cv_comm_buf); + } + __gm__ uint64_t *profile_entry = nullptr; + if (profile_buf != nullptr) { + std::size_t profile_block_base = static_cast(block_idx) * kFaProfileBytesPerBlock; + std::size_t profile_offset = profile_block_base; + if constexpr (DAV_VEC) { + profile_offset += + (static_cast(get_subblockid()) + 1U) * 1024U; // vec subblock 0/1 use 2nd/3rd KB + } + profile_entry = reinterpret_cast<__gm__ uint64_t *>(profile_buf + profile_offset); + profile_entry[0] = tStart; + } 
+ const size_t p_fifo_block_stride = + static_cast(qkp_tile_fifo_size) * static_cast(Cube_S0) * static_cast(Tile_S1); + const size_t p_max_fifo_block_stride = static_cast(qkp_tile_fifo_size) * static_cast(Cube_S0); + const size_t qk_fifo_block_stride = p_fifo_block_stride; + const size_t pv_fifo_block_stride = + static_cast(pv_tile_fifo_size) * static_cast(Cube_S0) * static_cast(HEAD_SIZE); + + __gm__ half *q_block = q + block_offset_rows * HEAD_SIZE; + __gm__ half *p_tile_fifo_block = p_tile_fifo + static_cast(comm_slot) * p_fifo_block_stride; + __gm__ float *exp_max_ififo_block = exp_max_ififo + static_cast(comm_slot) * p_max_fifo_block_stride; + __gm__ float *global_sum_block = global_sum_out + block_offset_rows; + __gm__ float *exp_max_block = exp_max_out + block_offset_rows; + __gm__ float *o_out_block = o_out + static_cast(block_offset_rows) * static_cast(HEAD_SIZE); + __gm__ float *o_parts_block = o_parts_out + static_cast(block_offset_rows) * static_cast(HEAD_SIZE); + __gm__ float *qk_tile_fifo_block = qk_tile_fifo + static_cast(comm_slot) * qk_fifo_block_stride; + __gm__ float *pv_tile_fifo_block = pv_tile_fifo + static_cast(comm_slot) * pv_fifo_block_stride; + + int num_tiles_s1 = S1 / Tile_S1; + if constexpr (CAUSAL_MASK) + num_tiles_s1 = (1 + ((block_idx * CUBE_S0) / Tile_S1)); + if constexpr (DAV_CUBE) { + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + } + if constexpr (DAV_VEC) { + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + } + + int p_gu_src_pingpong_id = 0; // shared ping-pong for softmax vec tiles, pv output tiles, and GU input tiles + int k_src_pingpong_id = 0; // separate ping-pong for K tiles + int 
pv_src_pingpong_id = 0; // separate ping-pong for P V tiles + + int qkAccTileEvtID = 0; + int pvAccTileEvtID = 0; + + // FIFO definitions + constexpr uint8_t FiFoDepth = CV_FIFO_SIZE; +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) + constexpr uint8_t QK_PIPE_DIR = Direction::DIR_C2V; + constexpr uint8_t P_PIPE_DIR = Direction::DIR_V2C; + constexpr uint8_t PV_PIPE_DIR = Direction::DIR_C2V; +#elif defined(__DAV_C310_CUBE__) || defined(__DAV_C310_VEC__) + constexpr uint8_t QK_PIPE_DIR = Direction::DIR_C2V_GM; + constexpr uint8_t P_PIPE_DIR = Direction::DIR_V2C_GM; + constexpr uint8_t PV_PIPE_DIR = Direction::DIR_C2V_GM; +#endif + + using QKPipe = TPipe; + QKPipe qkPipe(qk_tile_fifo_block, (uint32_t)(uint64_t)qkVecTile[0].data(), 0x0); + + // pFiFo, pProd, pCons + using PPipe = TPipe; + PPipe pPipe(p_tile_fifo_block, 0x0, (uint32_t)(uint64_t)pMatTile[0].data()); + + // pvFiFo, pvProd, pvCons + using PVPipe = TPipe; + PVPipe pvPipe(pv_tile_fifo_block, (uint32_t)(uint64_t)pvVecTile[0].data(), 0x0); + + using QKSlotGlobal = GlobalTensor, pto::Stride<1, 1, 1, Tile_S1, 1>>; + using QKVecSlotGlobal = + GlobalTensor, pto::Stride<1, 1, 1, Tile_S1, 1>>; + using PSlotGlobal = GlobalTensor, pto::Stride<1, 1, 1, Tile_S1, 1>>; + using PVecSlotGlobal = + GlobalTensor, pto::Stride<1, 1, 1, Tile_S1, 1>>; + using PVSlotGlobal = + GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + QKSlotGlobal qkSlotGlobal; + QKVecSlotGlobal qkVecSlotGlobal; + PSlotGlobal pSlotGlobal; + PVecSlotGlobal pVecSlotGlobal; + PVSlotGlobal pvSlotGlobal; + + // QK and P pre-computation (tile_id based) + for (int preload_tile = 0; preload_tile < static_cast(qkPreloadNum) && preload_tile < num_tiles_s1; + ++preload_tile) { + if constexpr (DAV_CUBE) { + for (int sub_tile = 0; sub_tile < static_cast(kTileFactor); ++sub_tile) { + qkAccTileEvtID = assign_running_acc_tile(qkAccTile); + compute_qk( + qkPipe, preload_tile, sub_tile, q_block, k, qk_tile_fifo_block, qMatTile[0], + kMatTile[k_src_pingpong_id 
% kMatTNBuffers], qkAccTile, qkSlotGlobal, + k_src_pingpong_id % kMatTNBuffers, qkAccTileEvtID, block_idx); + k_src_pingpong_id++; + } + } + if constexpr (DAV_VEC) { + for (int row_slice = 0; row_slice < static_cast(kTileFactor); ++row_slice) { + // Init only on the very first S1 tile; row_slice partitions rows within that tile + compute_p( + qkPipe, pPipe, preload_tile, row_slice, exp_max_ififo_block, qk_tile_fifo_block, p_tile_fifo_block, + global_sum_block, exp_max_block, qkVecTile[p_gu_src_pingpong_id % srcVecTNBuffers], + x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], input_reduce_tmp, m1_local_max, l1_local_sum, + m2_global_max, l2_global_sum, l1_exp_max_ififo[preload_tile % qkp_tile_fifo_size], triu, + qkVecSlotGlobal, pVecSlotGlobal, p_gu_src_pingpong_id % xexpVecTNBuffers, block_idx); + p_gu_src_pingpong_id++; + } + } + } + + for (int tile_id = 0; tile_id < num_tiles_s1; ++tile_id) { + int next_qk_tile = (tile_id + static_cast(qkPreloadNum) >= num_tiles_s1) ? + -1 : + (tile_id + static_cast(qkPreloadNum)); + + if (next_qk_tile != -1) + qkAccTileEvtID = assign_running_acc_tile(qkAccTile); + pvAccTileEvtID = assign_running_acc_tile(pvAccTile); + + for (int sub_tile = 0; sub_tile < static_cast(kTileFactor); ++sub_tile) { + if constexpr (DAV_CUBE) { + if (next_qk_tile != -1) { + compute_qk( + qkPipe, next_qk_tile, sub_tile, q_block, k, qk_tile_fifo_block, qMatTile[0], + kMatTile[k_src_pingpong_id % kMatTNBuffers], qkAccTile, qkSlotGlobal, + k_src_pingpong_id % kMatTNBuffers, qkAccTileEvtID, block_idx); + k_src_pingpong_id++; + } + } + + if constexpr (DAV_VEC) { + if (next_qk_tile != -1) { + compute_p( + qkPipe, pPipe, next_qk_tile, sub_tile, exp_max_ififo_block, qk_tile_fifo_block, + p_tile_fifo_block, global_sum_block, exp_max_block, + qkVecTile[p_gu_src_pingpong_id % srcVecTNBuffers], + x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], input_reduce_tmp, m1_local_max, l1_local_sum, + m2_global_max, l2_global_sum, l1_exp_max_ififo[next_qk_tile % 
qkp_tile_fifo_size], triu, + qkVecSlotGlobal, pVecSlotGlobal, p_gu_src_pingpong_id % xexpVecTNBuffers, block_idx); + p_gu_src_pingpong_id++; + } + } + + if constexpr (DAV_CUBE) { + compute_pv( + pPipe, pvPipe, tile_id, sub_tile, v, p_tile_fifo_block, + pMatTile[pv_src_pingpong_id % pMatTNBuffers], vMatTile[pv_src_pingpong_id % vMatTNBuffers], + pvAccTile, pSlotGlobal, pvSlotGlobal, pv_src_pingpong_id % vMatTNBuffers + PV_EVENT_ID0, + pvAccTileEvtID, block_idx); + pv_src_pingpong_id++; + } + } + + if constexpr (DAV_VEC) { + compute_gu( + pvPipe, tile_id, num_tiles_s1, o_out_block, o_parts_block, runningOTile, + pvVecTile[p_gu_src_pingpong_id % outOTileNBuffers], l1_exp_max_ififo[tile_id % qkp_tile_fifo_size], + l2_global_sum, p_gu_src_pingpong_id % outOTileNBuffers); + p_gu_src_pingpong_id++; + } + } + + if constexpr (DAV_CUBE) { + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + } + + if constexpr (DAV_VEC) { + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + } + + pipe_barrier(PIPE_ALL); + uint64_t tEnd = get_sys_cnt(); + if (profile_entry != nullptr) { + profile_entry[1] = tEnd; + } +#ifdef _DEBUG + if constexpr (DAV_CUBE) { + cce::printf("Core %d Cube Block %d, Start @%d End @%d (%d us)\n", get_coreid(), block_idx, int(tStart), + int(tEnd), int(tEnd - tStart) * 20 / 1000); + } else { + cce::printf("Core %d Vec Block %d, SubBlock %d, Start @%d End @%d (%d us)\n", get_coreid(), block_idx, + int(get_subblockid()), int(tStart), int(tEnd), int(tEnd - tStart) * 20 / 1000); + } +#endif +} + +// Empty kernel to warm up cores +__global__ AICORE 
__attribute__((aic)) void warmup_kernel() +{} + +// Host wrapper +template +void LaunchTFA(uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo, + float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, + float *qk_tile_fifo, float *pv_tile_fifo, uint8_t *profile_data, aclrtStream stream, + uint8_t *cv_comm_buf) +{ + static_assert(S0 % CUBE_S0 == 0, "S0 must be divisible by CUBE_S0"); + constexpr uint32_t block_rows = S0 / CUBE_S0; + +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) + // Warm up all cores first, then prefetch q/k/v into L2 + warmup_kernel<<<24, nullptr, stream>>>(); + + const uint64_t tensor_elems = static_cast(S0) * static_cast(HEAD_SIZE); + const uint64_t tensor_bytes = tensor_elems * sizeof(half); + constexpr bool kPrefetchUseSdma = true; // simulation cannot use sdma + constexpr int kPrefetchAivCores = 40; // only used when kPrefetchUseSdma is false + + if constexpr (kPrefetchUseSdma) { + PTO_PREFETCH((__gm__ void *)q, tensor_bytes, stream); + PTO_PREFETCH((__gm__ void *)k, tensor_bytes, stream); + PTO_PREFETCH((__gm__ void *)v, tensor_bytes, stream); + } else { + PTO_PREFETCH((__gm__ void *)q, tensor_bytes, stream); + PTO_PREFETCH((__gm__ void *)k, tensor_bytes, stream); + PTO_PREFETCH((__gm__ void *)v, tensor_bytes, stream); + } +#endif + + runTFA<<>>( + (__gm__ uint64_t *)ffts, (half *)q, (half *)k, (half *)v, (half *)p_tile_fifo, exp_max_ififo, global_sum_out, + exp_max_out, o_out, o_parts_out, qk_tile_fifo, pv_tile_fifo, cv_comm_buf, profile_data); +} + +// Backward-compatible overload without profiling buffer +template +void LaunchTFA(uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo, + float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, + float *qk_tile_fifo, float *pv_tile_fifo, aclrtStream stream, uint8_t *cv_comm_buf) +{ + LaunchTFA(ffts, q, k, v, p_tile_fifo, 
exp_max_ififo, global_sum_out, exp_max_out, o_out, + o_parts_out, qk_tile_fifo, pv_tile_fifo, nullptr, stream, cv_comm_buf); +} + +#include "generated_cases.h" + +#define INSTANTIATE_TFA(S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CAUSAL_MASK) \ + template void LaunchTFA( \ + uint16_t * ffts, aclFloat16 * q, aclFloat16 * k, aclFloat16 * v, aclFloat16 * p_out, float *p_out_fp32, \ + float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, float *qk_out, float *pv_out, \ + uint8_t *profile_data, aclrtStream stream, uint8_t *cv_comm_buf); \ + template void LaunchTFA( \ + uint16_t * ffts, aclFloat16 * q, aclFloat16 * k, aclFloat16 * v, aclFloat16 * p_out, float *p_out_fp32, \ + float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, float *qk_out, float *pv_out, \ + aclrtStream stream, uint8_t *cv_comm_buf); \ + template void LaunchTFA( \ + uint16_t * ffts, aclFloat16 * q, aclFloat16 * k, aclFloat16 * v, aclFloat16 * p_out, float *p_out_fp32, \ + float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, float *qk_out, float *pv_out, \ + uint8_t *profile_data, aclrtStream stream, uint8_t *cv_comm_buf); \ + template void LaunchTFA( \ + uint16_t * ffts, aclFloat16 * q, aclFloat16 * k, aclFloat16 * v, aclFloat16 * p_out, float *p_out_fp32, \ + float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, float *qk_out, float *pv_out, \ + aclrtStream stream, uint8_t *cv_comm_buf); + +TFA_FOR_EACH_CASE(INSTANTIATE_TFA) + +#undef INSTANTIATE_TFA \ No newline at end of file diff --git a/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.h b/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.h new file mode 100644 index 00000000..6a5bf0eb --- /dev/null +++ b/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.h @@ -0,0 +1,43 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#ifndef FA_PERFORMANCE_KERNEL_H +#define FA_PERFORMANCE_KERNEL_H + +#include +#include +#include + +// Shared defaults for FA performance kernels and host driver +constexpr int kFaCvFifoSize = 8; +constexpr int kFaCvFifoConsSyncPeriod = kFaCvFifoSize / 2; +constexpr int kFaCubeS1 = 128; +constexpr int kFaTileS1 = 256; +constexpr int kFaQkPreload = 4; +constexpr std::size_t kFaProfileBytesPerBlock = 1024 * 3; // cube + two vec subblocks +constexpr std::size_t kFaCvCommSlotBytes = 512U; +constexpr int VEC_CORES = 2; // Default to 2 vector cores per cube + +template +void LaunchTFA(uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo, + float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, + float *qk_tile_fifo, float *pv_tile_fifo, uint8_t *profile_data, aclrtStream stream, + uint8_t *cv_comm_buf = nullptr); + +// Overload without profiling buffer. 
+template +void LaunchTFA(uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo, + float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, + float *qk_tile_fifo, float *pv_tile_fifo, aclrtStream stream, uint8_t *cv_comm_buf = nullptr); + +#endif // FA_PERFORMANCE_KERNEL_H \ No newline at end of file diff --git a/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/pto_macro_fa_gu.hpp b/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/pto_macro_fa_gu.hpp new file mode 100644 index 00000000..3e4d0a9a --- /dev/null +++ b/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/pto_macro_fa_gu.hpp @@ -0,0 +1,59 @@ +/* +Copyright (c) 2025 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#ifndef PTO_MACRO_FA_GU_HPP +#define PTO_MACRO_FA_GU_HPP + +#include +#include +#include + +namespace pto { + +// ----------------------------------------------------------------------------- +// FlashAttention "GU" (running update) macro +// +// This implements the numerically-stable streaming update: +// O = O * exp(max_prev - max_new) + PV_tile +// and on the last tile: +// O = O / global_sum +// +// Performance notes: +// - Keep O resident in UB across tiles to avoid extra TLOAD/TSTORE. +// - exp_max and global_sum are per-row reduced tiles (shape [S0, 1]) that get broadcast over columns. 
+// ----------------------------------------------------------------------------- + +template +AICORE inline void pto_macro_fa_gu(svTileData __out__ prev_sv_tile, svTileData __in__ est_sv_tile, + reducedTileData __in__ exp_max) +{ + pto::TROWEXPANDMUL(prev_sv_tile, prev_sv_tile, exp_max); + pto::TADD(prev_sv_tile, prev_sv_tile, est_sv_tile); +} + +template +AICORE inline void pto_macro_fa_gu_last(svTileData __out__ prev_sv_tile, svTileData __in__ est_sv_tile, + reducedTileData __in__ exp_max, reducedTileData __in__ new_global_sum) +{ + pto::TROWEXPANDMUL(prev_sv_tile, prev_sv_tile, exp_max); + pto::TADD(prev_sv_tile, prev_sv_tile, est_sv_tile); + pto::TROWEXPANDDIV(prev_sv_tile, prev_sv_tile, new_global_sum); + // pto::TCVT(prev_sv_nd_tile, prev_sv_tile, RoundMode::CAST_RINT); +} + +template +AICORE inline void pto_macro_fa_gu_single_and_last_tile(svTileData __out__ sv_tile, + reducedTileData __in__ new_global_sum) +{ + pto::TROWEXPANDDIV(sv_tile, sv_tile, new_global_sum); +} + +} // namespace pto +#endif // PTO_MACRO_FA_GU_HPP diff --git a/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/pto_macro_fa_softmax.hpp b/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/pto_macro_fa_softmax.hpp new file mode 100644 index 00000000..685f29b4 --- /dev/null +++ b/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/pto_macro_fa_softmax.hpp @@ -0,0 +1,219 @@ +/* +Copyright (c) 2026 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+See LICENSE in the root of the software repository for the full text of the License. +*/ + +#ifndef PTO_MACRO_FA_SOFTMAX_HPP +#define PTO_MACRO_FA_SOFTMAX_HPP + +#include + +namespace pto { + +// ----------------------------------------------------------------------------- +// FlashAttention streaming softmax (tile-level) +// +// Given one QK tile X (fp32), compute x_exp = exp(scale * (X - new_global_max)). +// This function maintains per-row running state (global_max, global_sum) so that we can +// stream over S1 tiles without materializing the full attention matrix. +// +// Performance notes: +// - Keep intermediate computations in fp32 for numerical stability. +// - The `init` specialization initializes running state for the first S1 tile. +// - The 2D->1D reshape for TCVT is used to avoid layout constraints and keep the cast fast. +// ----------------------------------------------------------------------------- + +constexpr PTO_INTERNAL float constexpr_sqrt(float x) +{ + if (x <= 0.0f) + return 0.0f; + float guess = x; + for (int i = 0; i < 8; ++i) { + guess = 0.5f * (guess + x / guess); + } + return guess; +} + +constexpr AICORE inline float constexpr_inv_sqrt(float x) +{ + return 1.0f / constexpr_sqrt(x); +} + +template +AICORE inline void softmax_opt_fa_init_impl(TileDataD2 __out__ x_exp, TileDataS1 __in__ input_x, + ReduceTileD1 __out__ local_max, ReduceTileD1 __out__ local_sum, + ReduceTileD1 __out__ new_global_max, ReduceTileD1 __out__ new_global_sum, + ReduceTileD1 __out__ exp_max, TileDataS1 __out__ tmp_float, + TileDataS1 __out__ p_tile_f32, TileDataS1 triu, int s0_index, int s1_index) +{ + (void)local_max; + (void)exp_max; + (void)local_sum; + + constexpr float scale = constexpr_inv_sqrt(HEAD_SIZE); + using Tile1D_fp32 = Tile; + using Tile1D_out = Tile; + Tile1D_fp32 p_tile_f32_1d; + Tile1D_out x_exp_1d; + if constexpr (CAUSAL_MASK) { + if (s0_index / TileDataS1::Cols == s1_index / TileDataS1::Cols) { + constexpr float negInf = -3.40282e+38; + 
// Streaming-softmax update for every S1 tile after the first one.
//
// In terms of the calls below, with m_old = new_global_max on entry:
//   m_new          = max(rowmax(input_x), m_old)
//   exp_max        = exp(scale * (m_old - m_new))
//   p_tile_f32     = exp(scale * (input_x - m_new))
//   x_exp          = cast(p_tile_f32)
//   new_global_max = m_new
//   new_global_sum = exp_max * new_global_sum + rowsum(p_tile_f32)
// where scale = 1 / sqrt(HEAD_SIZE).
//
// new_global_max and new_global_sum are read before they are written, so they are effectively
// in/out even though they are annotated __out__. local_max ends up holding m_new (not the plain
// row max of this tile), local_sum holds this tile's row sums, and input_x is modified in place
// when the causal mask is applied. tmp_float is scratch for the row reductions.
// NOTE(review): the tmp_shw_* tiles are presumably reshaped aliases of the corresponding
// per-row tiles (TRESHAPE), so writes through them update local_max / new_global_max /
// exp_max / new_global_sum - confirm against the pto-isa TRESHAPE semantics.
template
AICORE inline void softmax_opt_fa_not_init_impl(TileDataD2 __out__ x_exp, TileDataS1 __in__ input_x,
                                                ReduceTileD1 __out__ local_max, ReduceTileD1 __out__ local_sum,
                                                ReduceTileD1 __out__ new_global_max,
                                                ReduceTileD1 __out__ new_global_sum, ReduceTileD1 __out__ exp_max,
                                                TileDataS1 __out__ tmp_float, TileDataS1 __out__ p_tile_f32,
                                                TileDataS1 triu, int s0_index, int s1_index)
{
    constexpr float scale = constexpr_inv_sqrt(HEAD_SIZE);

    using ReduceTileD2 = Tile;
    using Tile1D_fp32 = Tile;
    using Tile1D_out = Tile;

    ReduceTileD2 tmp_shw_local_max;
    ReduceTileD2 tmp_shw_new_global_max;
    ReduceTileD2 tmp_shw_exp_max;
    ReduceTileD2 tmp_shw_new_global_sum;
    ReduceTileD2 tmp_shw_local_sum;
    Tile1D_fp32 p_tile_f32_1d;
    Tile1D_out x_exp_1d;

    if constexpr (CAUSAL_MASK) {
        // Only the tile whose S0 and S1 tile indices coincide gets a mask added here.
        if (s0_index / TileDataS1::Cols == s1_index / TileDataS1::Cols) {
            // Large finite negative value (about -FLT_MAX), not an actual infinity.
            constexpr float negInf = -3.40282e+38;
            // Build the mask in triu, scale it by negInf and add it onto the scores.
            // NOTE(review): presumably TTRI fills triu with 0/1 entries above a diagonal - confirm.
            TTRI(triu, 1 + (s0_index % TileDataS1::Cols));
#if defined(__DAV_C220_VEC__)
            pipe_barrier(PIPE_V);
#endif
            TMULS(triu, triu, negInf);
#if defined(__DAV_C220_VEC__)
            pipe_barrier(PIPE_V);
#endif
            TADD(input_x, input_x, triu);
#if defined(__DAV_C220_VEC__)
            pipe_barrier(PIPE_V);
#endif
        }
    }
    // FA2.0 streaming mode (not first tile): update (global_max, global_sum) and rescale old sums.
    TROWMAX(local_max, input_x, tmp_float);
#if defined(__DAV_C220_VEC__)
    pipe_barrier(PIPE_V);
#endif
    TRESHAPE(tmp_shw_local_max, local_max);
    TRESHAPE(tmp_shw_new_global_max, new_global_max);
    // local_max := max(local_max, old global max), i.e. the new running max.
    TMAX(tmp_shw_local_max, tmp_shw_local_max, tmp_shw_new_global_max);
#if defined(__DAV_C220_VEC__)
    pipe_barrier(PIPE_V);
#endif
    TRESHAPE(tmp_shw_exp_max, exp_max);
    // exp_max := old max - new max (exponentiated further down).
    TSUB(tmp_shw_exp_max, tmp_shw_new_global_max, tmp_shw_local_max);
#if defined(__DAV_C220_VEC__)
    pipe_barrier(PIPE_V);
#endif

    TMULS(tmp_shw_new_global_max, tmp_shw_local_max, 1.0f); // just copy
#if defined(__DAV_C220_VEC__)
    pipe_barrier(PIPE_V);
#endif
    TROWEXPANDSUB(p_tile_f32, input_x, local_max);
    TMULS(tmp_shw_exp_max, tmp_shw_exp_max, scale);
    TMULS(p_tile_f32, p_tile_f32, scale);
    TEXP(tmp_shw_exp_max, tmp_shw_exp_max);
    // NOTE(review): this TRESHAPE and the one two lines below repeat the reshape already done
    // before the TSUB above - confirm whether they are needed or were meant to be something else.
    TRESHAPE(tmp_shw_exp_max, exp_max);
    TEXP(p_tile_f32, p_tile_f32);
    TRESHAPE(tmp_shw_exp_max, exp_max);

    // Cast the fp32 probabilities into x_exp through 1D reshaped tiles.
    TRESHAPE(p_tile_f32_1d, p_tile_f32);
    TRESHAPE(x_exp_1d, x_exp);
    TCVT(x_exp_1d, p_tile_f32_1d, RoundMode::CAST_ROUND);
#if defined(__DAV_C220_VEC__)
    pipe_barrier(PIPE_V);
#endif
    TRESHAPE(tmp_shw_new_global_sum, new_global_sum);
    // Rescale the old running sum by exp_max, then add this tile's row sums.
    TMUL(tmp_shw_new_global_sum, tmp_shw_exp_max, tmp_shw_new_global_sum);
    TROWSUM(local_sum, p_tile_f32, tmp_float);
    TRESHAPE(tmp_shw_local_sum, local_sum);
#if defined(__DAV_C220_VEC__)
    pipe_barrier(PIPE_V);
#endif
    TADD(tmp_shw_new_global_sum, tmp_shw_new_global_sum, tmp_shw_local_sum);
}
+{ + if (s1_index <= s0_index || !CAUSAL_MASK) { + if constexpr (init) { + softmax_opt_fa_init_impl( + x_exp, input_x, local_max, local_sum, new_global_max, new_global_sum, exp_max, input_reduce_tmp, + p_tile_fp32, triu, s0_index, s1_index); + } else { + softmax_opt_fa_not_init_impl( + x_exp, input_x, local_max, local_sum, new_global_max, new_global_sum, exp_max, input_reduce_tmp, + p_tile_fp32, triu, s0_index, s1_index); + } + } else if constexpr (CAUSAL_MASK) { + TMULS(x_exp, x_exp, 0.0); + TMULS(exp_max, exp_max, 0.0); +#if defined(__DAV_C220_VEC__) + pipe_barrier(PIPE_V); +#endif + TADDS(exp_max, exp_max, 1.0); +#if defined(__DAV_C220_VEC__) + pipe_barrier(PIPE_V); +#endif + } +} + +} // namespace pto + +#endif diff --git a/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/pto_macro_matmul.hpp b/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/pto_macro_matmul.hpp new file mode 100644 index 00000000..69a60b63 --- /dev/null +++ b/examples/aot/flash_attention/cpp_ref/split_pipe/kernels/flash_atten/pto_macro_matmul.hpp @@ -0,0 +1,212 @@ +/* +Copyright (c) 2026 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#ifndef PTO_MACRO_MATMUL_HPP +#define PTO_MACRO_MATMUL_HPP + +#include +#include + +#define CUBE_K_256 256 +#define CUBE_K_128 128 +#define CUBE_K_64 64 +#define CUBE_K_SMALLEST 32 + +namespace pto { + +/** + * Layout type for matrix multiplication operations. 
+ * First letter represents the layout of matrix A, second letter represents matrix B. + * N = Normal (Row-major), T = Transposed (Column-major) + */ +enum class layout_t +{ + NN, // Matrix A: Normal, Matrix B: Normal + NT, // Matrix A: Normal, Matrix B: Transposed + TN, // Matrix A: Transposed, Matrix B: Normal + TT, // Matrix A: Transposed, Matrix B: Transposed + NONE +}; + +enum class AccMode +{ + Init, // auto phase, first slice initializes, rest accumulate + Acc, // auto phase, all slices accumulate into existing C + InitPartialSum, // explicitly partial, first slice initializes + InitFinalSum, // explicitly final, first slice initializes + AccPartialSum, // explicitly partial, all slices accumulate + AccFinalSum, // explicitly final, all slices accumulate +}; + +#define L0A_BUF0 ((__ca__ half *)(__ca__ char *)0x0) +#define L0A_BUF1 ((__ca__ half *)(__ca__ char *)0x8000) +#define L0B_BUF0 ((__ca__ half *)(__ca__ char *)0x0) +#define L0B_BUF1 ((__ca__ half *)(__ca__ char *)0x8000) +#define L0C_BUF0 ((__ca__ half *)(__ca__ char *)0x0) +#define L0C_BUF1 ((__ca__ half *)(__ca__ char *)0x20000) + +#define LAST_LOOP(x, n) ((x) == ((n)-1)) +#define UNIT_FLAG_ENABLE(i, n) (LAST_LOOP(i, n) ? 3 : 2) + +AICORE inline uint64_t getPingPong(uint32_t flip) +{ + static uint64_t pingpong = 0; + if (flip) { + pingpong = 1 - pingpong; + } + return pingpong; +} + +// Memory constraints (L0 ping-pong is 32 KiB per buffer in this implementation). +// Tuning knob: if you change L0 layout or buffer addresses, re-check these constraints. +constexpr uint32_t MEM_BUFFER_SIZE_BYTES = 64 * 1024 / 2; // 64KB per buffer with pingpong (32KB) +constexpr uint32_t HALF_SIZE_BYTES = 2; // sizeof(half) = 2 bytes + +/** + * Calculate the largest Cube_K value that fits in the 64KB memory buffer. + * Checks if both Cube_M * Cube_K (left matrix) and Cube_K * Cube_N (right matrix) + * can fit within the 64KB buffer. 
+ * + * @param Cube_M - The tile dimension M + * @param Cube_N - The tile dimension N + * @return - Largest Cube_K value (32, 64, 128, or 256) that fits in memory + */ +// Choose the largest Cube_K that fits both L0A (Cube_M x Cube_K) and L0B (Cube_K x Cube_N) +// so TMATMUL stays compute-dense while respecting L0 ping-pong capacity. +AICORE inline constexpr uint32_t calculateFittingCubeK(uint32_t Cube_M, uint32_t Cube_N) +{ + uint32_t bestCubeK = CUBE_K_SMALLEST; // Default to smallest value + + // Test candidates from largest to smallest to find the largest that fits + if (Cube_M * CUBE_K_256 * HALF_SIZE_BYTES <= MEM_BUFFER_SIZE_BYTES && + CUBE_K_256 * Cube_N * HALF_SIZE_BYTES <= MEM_BUFFER_SIZE_BYTES) { + bestCubeK = CUBE_K_256; + } else if (Cube_M * CUBE_K_128 * HALF_SIZE_BYTES <= MEM_BUFFER_SIZE_BYTES && + CUBE_K_128 * Cube_N * HALF_SIZE_BYTES <= MEM_BUFFER_SIZE_BYTES) { + bestCubeK = CUBE_K_128; + } else if (Cube_M * CUBE_K_64 * HALF_SIZE_BYTES <= MEM_BUFFER_SIZE_BYTES && + CUBE_K_64 * Cube_N * HALF_SIZE_BYTES <= MEM_BUFFER_SIZE_BYTES) { + bestCubeK = CUBE_K_64; + } + + return bestCubeK; +} + +// Deduce layout_t from SLayouts +template +AICORE inline constexpr layout_t deduce_layout() +{ + if constexpr (TileDataA::SFractal == SLayout::RowMajor && TileDataB::SFractal == SLayout::RowMajor) + return layout_t::NN; + if constexpr (TileDataA::SFractal == SLayout::RowMajor && TileDataB::SFractal == SLayout::ColMajor) + return layout_t::NT; + if constexpr (TileDataA::SFractal == SLayout::ColMajor && TileDataB::SFractal == SLayout::RowMajor) + return layout_t::TN; + if constexpr (TileDataA::SFractal == SLayout::ColMajor && TileDataB::SFractal == SLayout::ColMajor) + return layout_t::TT; + return layout_t::NONE; +} + +struct MatmulCallConfig { + bool useAcc; // true -> TMATMUL_UF_ACC, false -> TMATMUL_UF + AccPhase phase; // UF mapping +}; + +AICORE inline MatmulCallConfig resolve_acc_mode(AccMode mode, bool isFirstSlice, bool isLastSlice) +{ + switch (mode) { + case 
// Tiled matmul macro: splits the K dimension into Cube_K-sized segments and, for each segment,
// extracts the A/B panels into L0A/L0B and issues one TMATMUL / TMATMUL_ACC into cAccTile.
//
// aMatTile / bMatTile : source operand tiles; for the NT layout they are re-pointed with
//                       TASSIGN inside the loop, so the caller's tile objects are modified.
// cAccTile            : accumulator tile that receives the result.
// accMode             : selects, via resolve_acc_mode, whether a segment initializes or
//                       accumulates into cAccTile.
//
// Synchronization: every iteration first waits on the PIPE_M -> PIPE_MTE1 flag for the current
// ping-pong slot and sets it again after the matmul. The first wait has no matching set_flag in
// this function, so the caller presumably primes these flags before the first call - confirm.
template
AICORE inline void pto_macro_matmul(TileDataA &aMatTile, TileDataB &bMatTile, TileDataC &cAccTile,
                                    AccMode accMode = AccMode::Init)
{
    constexpr layout_t layout = deduce_layout();

    static_assert(layout != layout_t::NONE, "Deduced layout is NONE, check tile SLayouts");
    // Assert that template LAYOUT matches deduced layout if LAYOUT is not NONE
    if constexpr (LAYOUT != layout_t::NONE) {
        static_assert(LAYOUT == layout,
                      "Layout mismatch: template LAYOUT does not match deduced layout from tile SLayouts. "
                      "Check SLayout of TileDataA and TileDataB.");
    }

    // Ping-pong is used to overlap TEXTRACT (L1->L0) with TMATMUL on alternating buffers.
    // getPingPong(0) reads the current slot without flipping it, so the slot carries over
    // from the previous call of this macro.
    uint64_t pingpong = getPingPong(0);
    // Cube_K is the fitting slice size, clamped to Tile_K.
    const uint64_t Cube_K =
        calculateFittingCubeK(Cube_M, Cube_N) > Tile_K ? Tile_K : calculateFittingCubeK(Cube_M, Cube_N);
    // Integer division: a remainder of Tile_K that is not a multiple of Cube_K is not processed.
    // NOTE(review): presumably Tile_K is always a multiple of Cube_K - confirm at the call sites.
    const uint64_t kSegments = (uint64_t)(Tile_K / Cube_K);
    for (uint64_t k = 0; k < kSegments; k++) {
        using LeftTile = TileLeft;
        LeftTile al0Tiles[2] = {LeftTile(), LeftTile()};
        using RightTile = TileRight;
        RightTile bl0Tiles[2] = {RightTile(), RightTile()};

        // Bind the two ping-pong slots of L0A / L0B to their fixed buffer addresses.
        TASSIGN(al0Tiles[0], (uint64_t)L0A_BUF0);
        TASSIGN(al0Tiles[1], (uint64_t)L0A_BUF1);
        TASSIGN(bl0Tiles[0], (uint64_t)L0B_BUF0);
        TASSIGN(bl0Tiles[1], (uint64_t)L0B_BUF1);

        // Wait until previous TMATMUL finishes using this L0 buffer before overwriting it via TEXTRACT.
        wait_flag(PIPE_M, PIPE_MTE1, pingpong);

        if (layout == layout_t::NT) {
            // NOTE(review): the offset k * Cube_K * ... is added to the tile's *current* address.
            // If TASSIGN changes what data() returns, the offsets accumulate across iterations
            // (0, 1, 3, 6, ... segments), which only matches a linear walk for kSegments <= 2 -
            // confirm this is intended.
            TASSIGN(aMatTile, (uint64_t)aMatTile.data() + k * Cube_K * Cube_M * sizeof(typename TileDataA::DType));
            TASSIGN(bMatTile, (uint64_t)bMatTile.data() + k * Cube_K * Cube_N * sizeof(typename TileDataB::DType));
        }

        // TEXTRACT slices the current Cube_K panel into L0A/L0B.
        // The extract offsets are always (0, 0); only the NT branch above moves the source tiles.
        TEXTRACT(al0Tiles[pingpong], aMatTile, 0, 0);
        TEXTRACT(bl0Tiles[pingpong], bMatTile, 0, 0);

        // Make the matmul pipe wait for the extracts that were just issued.
        set_flag(PIPE_MTE1, PIPE_M, pingpong);
        wait_flag(PIPE_MTE1, PIPE_M, pingpong);

        const bool isLast = (k + 1 == kSegments);
        MatmulCallConfig cfg = resolve_acc_mode(accMode, k == 0, isLast);
        // NOTE(review): as written, the three phase branches on each side issue the same call;
        // only cfg.useAcc changes which instruction is used - confirm whether the phase is
        // meant to select a different variant.
        if (cfg.useAcc) {
            if (cfg.phase == AccPhase::Final) {
                TMATMUL_ACC(cAccTile, al0Tiles[pingpong], bl0Tiles[pingpong]);
            } else if (cfg.phase == AccPhase::Partial) {
                TMATMUL_ACC(cAccTile, al0Tiles[pingpong], bl0Tiles[pingpong]);
            } else {
                TMATMUL_ACC(cAccTile, al0Tiles[pingpong], bl0Tiles[pingpong]);
            }
        } else {
            if (cfg.phase == AccPhase::Final) {
                TMATMUL(cAccTile, al0Tiles[pingpong], bl0Tiles[pingpong]);
            } else if (cfg.phase == AccPhase::Partial) {
                TMATMUL(cAccTile, al0Tiles[pingpong], bl0Tiles[pingpong]);
            } else {
                TMATMUL(cAccTile, al0Tiles[pingpong], bl0Tiles[pingpong]);
            }
        }
        // Release this L0 slot for the next TEXTRACT, then flip to the other slot.
        set_flag(PIPE_M, PIPE_MTE1, pingpong);
        pingpong = getPingPong(1);
    }
}
b/examples/aot/flash_attention/cpp_ref/split_pipe/results/fa_splitpipe_tile512.png differ diff --git a/examples/aot/flash_attention/cpp_ref/split_pipe/run.py b/examples/aot/flash_attention/cpp_ref/split_pipe/run.py new file mode 100644 index 00000000..f406f6f7 --- /dev/null +++ b/examples/aot/flash_attention/cpp_ref/split_pipe/run.py @@ -0,0 +1,207 @@ +#!/usr/bin/python3 +# coding=utf-8 +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# This program is free software, you can redistribute it and/or modify it under the +# terms and conditions of CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance +# with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, +# OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# -------------------------------------------------------------------------------- + +# Split-pipe FA: JIT-compiles bundled kernels/flash_atten/fa_performance_kernel.cpp (TileSplitAxis, etc.) +# plus call_kernel_dispatch.cpp. Template instantiations are listed in generated_cases.h (regenerate via scripts/generate_cases.py). 
def attn_flops_matmul_softmax_scale(
    batch_size: int,
    s_q: int,
    s_k: int,
    h: int,
    include_scale: bool = True,
    count_exp_as_flop: bool = True,
    count_max_as_flop: bool = True,
):
    """Return the operation count used to convert attention timings to TFLOP/s.

    The total is the sum of three parts:

    * matmuls: ``4 * batch_size * s_q * s_k * h``;
    * score scaling: one op per score element, only if ``include_scale``;
    * softmax, per score row: ``s_k - 1`` ops for the max (only if
      ``count_max_as_flop``), ``s_k`` for the subtraction, ``s_k`` for the
      exponential (only if ``count_exp_as_flop``), ``s_k - 1`` for the sum
      and ``s_k`` for the division.
    """
    score_elems = batch_size * s_q * s_k
    total = 4 * score_elems * h
    if include_scale:
        total += score_elems

    # Softmax cost of one score row; the unconditional terms are
    # subtract (s_k), sum (s_k - 1) and divide (s_k).
    per_row = s_k + (s_k - 1) + s_k
    if count_max_as_flop:
        per_row += s_k - 1
    if count_exp_as_flop:
        per_row += s_k

    return total + batch_size * s_q * per_row
"regenerate generated_cases.h with more heads if needed." + ) + + s0 = 128 * 24 + s1_values = [1024, 2048, 4096, 8192, 16384, 32768, 64 * 1024, 128 * 1024] + bad_s1 = [s1 for s1 in s1_values if s1 % tile_s1 != 0] + if bad_s1: + raise ValueError(f"tile_s1={tile_s1} does not divide S1 values: {bad_s1}") + + dtype = torch.float16 + q2d = torch.randn((s0, head), dtype=dtype).npu() + flash = jit_compile_flash(verbose=False) + + flash_ms_values = [] + npu_ms_values = [] + ref_ms_values = [] + flash_tflops_values = [] + npu_tflops_values = [] + ref_tflops_values = [] + + for s1 in s1_values: + flops_total = attn_flops_matmul_softmax_scale(1, s0, s1, head) + + k2d = torch.randn((s1, head), dtype=dtype).npu() + v2d = torch.randn((s1, head), dtype=dtype).npu() + + # Custom bisheng FA kernels do not tolerate do_bench's default 256MB L2 flush between iterations. + ref_ms = do_bench( + lambda: fa_reference(q2d, k2d, v2d), + warmup_iters=WARMUP, + benchmark_iters=NUM_ITERATIONS, + unit="ms", + flush_cache=False, + ) + npu_ms = do_bench( + lambda: fused_attention(q2d, k2d, v2d), + warmup_iters=WARMUP, + benchmark_iters=NUM_ITERATIONS, + unit="ms", + flush_cache=False, + ) + flash_ms = do_bench( + lambda: flash(q2d, k2d, v2d, tile_s1=tile_s1), + warmup_iters=WARMUP, + benchmark_iters=NUM_ITERATIONS, + unit="ms", + flush_cache=False, + ) + + flash_ms_values.append(flash_ms) + npu_ms_values.append(npu_ms) + ref_ms_values.append(ref_ms) + flash_tflops_values.append(tflops(flops_total, flash_ms)) + npu_tflops_values.append(tflops(flops_total, npu_ms)) + ref_tflops_values.append(tflops(flops_total, ref_ms)) + + o_out = flash(q2d, k2d, v2d, tile_s1=tile_s1) + o_ref = fa_reference(q2d, k2d, v2d).to(torch.float32) + o_npu = fused_attention(q2d, k2d, v2d).to(torch.float32) + + print(f"S1 : {s1}") + print(f"Tile S1 : {tile_s1}") + print(f"FLOPs total : {flops_total}") + print( + f"JIT flash kernel : {flash_ms:.3f} ms/iter " + f"({tflops(flops_total, flash_ms):.3f} TFLOP/s)" + ) + print( + 
f"npu_fused_infer_attention : {npu_ms:.3f} ms/iter " + f"({tflops(flops_total, npu_ms):.3f} TFLOP/s)" + ) + print( + f"torch reference : {ref_ms:.3f} ms/iter " + f"({tflops(flops_total, ref_ms):.3f} TFLOP/s)" + ) + torch.testing.assert_close(o_out, o_ref, rtol=1e-3, atol=1e-3) + print("vs torch reference: PASSED") + torch.testing.assert_close(o_out, o_npu, rtol=1e-3, atol=1e-3) + print("vs npu_fused_attention: PASSED") + print("") + + plot_path = Path(__file__).with_name("fa_split_pipe_s1_plot.png") + plt.figure(figsize=(8, 5)) + plt.plot(s1_values, flash_tflops_values, marker="o", label="flash split_pipe") + plt.plot(s1_values, ref_tflops_values, marker="o", label="ref") + plt.plot(s1_values, npu_tflops_values, marker="o", label="torch_npu") + plt.xscale("log", base=2) + plt.xticks(s1_values, [str(v) for v in s1_values]) + plt.xlabel("S1") + plt.ylabel("TFLOP/s") + plt.title( + f"Split-pipe FA TFLOP/s vs S1 (S0={s0}, head={head}, tile_s1={tile_s1})" + ) + plt.grid(True, which="both", axis="both", linestyle="--", linewidth=0.5) + plt.legend() + plt.tight_layout() + plt.savefig(plot_path, dpi=160) + plt.close() + print(f"Saved plot to {plot_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--tile-s1", type=int, choices=(256, 512, 1024), default=512) + parser.add_argument("--head", type=int, choices=(128,), default=128) + args = parser.parse_args() + test_flash(tile_s1=args.tile_s1, head=args.head) diff --git a/examples/aot/flash_attention/cpp_ref/split_pipe/scripts/generate_cases.py b/examples/aot/flash_attention/cpp_ref/split_pipe/scripts/generate_cases.py new file mode 100644 index 00000000..591eb207 --- /dev/null +++ b/examples/aot/flash_attention/cpp_ref/split_pipe/scripts/generate_cases.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# -------------------------------------------------------------------------------- +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# -------------------------------------------------------------------------------- + +""" +Generate TFA case configuration and emit a shared header/JSON for host/kernel build. + +Usage examples: + python3 generate_cases.py --cases "128,128,1024,128,256" \ + --cases "128,512,2048,128,256" + + # Override cube-side preload depth (defaults to 4) + python3 generate_cases.py --qk-preload 6 + +Each --cases entry format: HEAD_SIZE,S0,S1,CUBE_S0[,TILE_S1] +CUBE_S1 is fixed at 128; TILE_S1 defaults to 256 if omitted. +Defaults replicate the previous hard-coded set if --cases is omitted. 
def _normalize_case(case: Dict[str, int]) -> Dict[str, int]:
    """Validate one case dict and normalize its cube sizes in place.

    Args:
        case: Case description with at least the keys ``qk_preload``, ``s0``,
            ``s1``, ``cube_s0``, ``cube_s1`` and ``tile_s1``.

    Returns:
        The same dict object, with ``cube_s0`` replaced by ``s0`` when it is
        non-positive, larger than ``s0`` or does not divide ``s0``, and with
        ``cube_s1`` forced to 128.

    Raises:
        ValueError: If ``qk_preload`` is below 1, ``s1`` is not a multiple of
            128, or ``tile_s1`` is not a positive multiple of 128 that
            divides ``s1``.
    """
    if case["qk_preload"] < 1:
        raise ValueError("qk_preload must be >= 1")

    # An unusable cube_s0 falls back to s0. The non-positive check comes first
    # so that a zero value never reaches the modulo below.
    cube_s0 = case["cube_s0"]
    if cube_s0 <= 0 or cube_s0 > case["s0"] or case["s0"] % cube_s0 != 0:
        case["cube_s0"] = case["s0"]

    # CUBE_S1 is fixed at 128 regardless of what the caller supplied.
    case["cube_s1"] = 128
    if case["s1"] % case["cube_s1"] != 0:
        raise ValueError("S1 must be divisible by CUBE_S1 (128)")

    # TILE_S1 must be a positive multiple of CUBE_S1 that divides S1. Zero or
    # negative values are rejected explicitly: they would otherwise either
    # divide by zero or slip through the modulo checks.
    tile_s1 = case["tile_s1"]
    if tile_s1 <= 0:
        raise ValueError("TILE_S1 must be a positive multiple of CUBE_S1 (128)")
    if tile_s1 % case["cube_s1"] != 0:
        raise ValueError("TILE_S1 must be divisible by CUBE_S1 (128)")
    if case["s1"] % tile_s1 != 0:
        raise ValueError("S1 must be divisible by TILE_S1")

    return case
def _parse_bool_flag(text: str) -> bool:
    """Convert the value given to ``--causal-mask`` into a bool.

    Accepts the usual spellings (true/false, yes/no, on/off) as well as any
    integer literal, where non-zero means True.

    Raises:
        argparse.ArgumentTypeError: If *text* is none of the above.
    """
    lowered = text.strip().lower()
    if lowered in ("true", "t", "yes", "y", "on"):
        return True
    if lowered in ("false", "f", "no", "n", "off"):
        return False
    try:
        return bool(int(lowered))
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected a boolean value (true/false/1/0), got '{text}'"
        ) from None


def main() -> None:
    """Parse the command line and write the generated case header and JSON.

    Cases come from ``--cases`` when given, otherwise from the built-in
    default set. Every case is normalized, then rendered into the header and
    dumped as JSON; both output directories are created if missing.
    """
    parser = argparse.ArgumentParser(description="Generate TFA case header/JSON")
    parser.add_argument(
        "--cases",
        action="append",
        default=None,
        help="Case entry in the format HEAD_SIZE,S0,S1,CUBE_S0[,TILE_S1] (repeat for multiple entries; CUBE_S1 fixed at 128)",
    )
    parser.add_argument(
        "--qk-preload",
        type=int,
        default=QK_PRELOAD_DEFAULT,
        help="qkPreloadNum (cube pipeline preload depth) applied to all generated cases",
    )
    parser.add_argument(
        "--output-header",
        default=str(_SPLIT_PIPE_DIR / "generated_cases.h"),
        help="Output header path (default: split_pipe/generated_cases.h)",
    )
    parser.add_argument(
        "--output-json",
        default=str(_SPLIT_PIPE_DIR / "build" / "generated_cases.json"),
        help="Output JSON path (default: split_pipe/build/generated_cases.json)",
    )
    # Without a type the value would arrive as a raw string, which the case
    # parser cannot convert for inputs such as "true" or "False". nargs="?"
    # with const=True also lets the bare flag "--causal-mask" mean True, while
    # the older "--causal-mask 1" / "--causal-mask 0" form keeps working.
    parser.add_argument(
        "--causal-mask",
        nargs="?",
        const=True,
        default=False,
        type=_parse_bool_flag,
        help="Enable causal mask for --cases entries (bare flag, or an explicit true/false/1/0 value)",
    )
    args = parser.parse_args()

    if args.cases:
        cases = [_normalize_case(_parse_case_entry(entry, args.qk_preload, args.causal_mask)) for entry in args.cases]
    else:
        # The built-in default cases carry their own causal-mask setting;
        # --causal-mask is not applied to them.
        cases = [_normalize_case(case) for case in _default_cases(args.qk_preload)]

    header_text = _render_header(cases)
    header_path = Path(args.output_header)
    header_path.parent.mkdir(parents=True, exist_ok=True)
    header_path.write_text(header_text)

    json_payload = [
        {
            "name": _case_name(case),
            **case,
        }
        for case in cases
    ]
    json_path = Path(args.output_json)
    json_path.parent.mkdir(parents=True, exist_ok=True)
    json_path.write_text(json.dumps(json_payload, indent=2))

    print(f"[INFO] Wrote {header_path}")
    print(f"[INFO] Wrote {json_path}")
    print("[INFO] Cases generated:")
    for case in json_payload:
        print(f"  - {case['name']} (H={case['head_size']}, S0={case['s0']}, S1={case['s1']}, CUBE_S0={case['cube_s0']}, CUBE_S1={case['cube_s1']}, TILE_S1={case['tile_s1']}, QK_PRELOAD={case['qk_preload']}, CAUSAL_MASK={case['causal_mask']})")
// Host-side C entry point that launches the kernel pulled in through KERNEL_CPP.
//
// blockDim     : block count for the launch (NOTE(review): presumably forwarded in the launch
//                configuration together with stream - confirm)
// stream       : stream handle for the launch
// gmSlotBuffer : buffer passed to the kernel as a float pointer
// q, k, v      : inputs, passed to the kernel as half pointers
// o            : output, passed to the kernel as a float pointer
extern "C" void call_kernel(
    uint32_t blockDim,
    void *stream,
    uint8_t *gmSlotBuffer,
    uint8_t *q,
    uint8_t *k, // K: [S1_TOTAL, HEAD] fp16
    uint8_t *v,
    uint8_t *o) // output O: [Q_ROWS, HEAD] fp32
{
    // Ask the runtime for the control address that is handed to the kernel as its first
    // argument. The return code and the reported length are deliberately discarded, so
    // fftsAddr stays nullptr if the query fails.
    void *fftsAddr = nullptr;
    uint32_t fftsLen = 0;
    (void)rtGetC2cCtrlAddr(reinterpret_cast(&fftsAddr), &fftsLen);
    (void)fftsLen;

    // The raw byte pointers are reinterpreted as the element types the kernel expects.
    call_both<<>>(
        (__gm__ int64_t *)fftsAddr,
        (__gm__ float *)gmSlotBuffer,
        (__gm__ half *)q,
        (__gm__ half *)k,
        (__gm__ half *)v,
        (__gm__ float *)o);
}
b/examples/aot/flash_attention/ir_ref/README.md @@ -0,0 +1,6 @@ +Copy IR from https://github.com/hw-native-sys/PTOAS/pull/609 +to reverse engineer python frontend + +```bash +bash ./gen_cpp.sh +``` diff --git a/examples/aot/flash_attention/ir_ref/fa.cpp b/examples/aot/flash_attention/ir_ref/fa.cpp new file mode 100644 index 00000000..5210838c --- /dev/null +++ b/examples/aot/flash_attention/ir_ref/fa.cpp @@ -0,0 +1,856 @@ +#include "pto/pto-inst.hpp" +using namespace pto; +/* Helper: returns tensor.data() for whatever tensor object it is given. */ +template +static AICORE inline auto PTOAS__GLOBAL_TENSOR_DATA(Tensor &tensor) + -> decltype(tensor.data()) { + return tensor.data(); +} + +/* Selects which synchronization ptoas_auto_sync_tail() performs. */ +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; +/* Sync helper called right before `return` in cube_kernel. kSetWaitMte3ToSEvent0 issues a set_flag/wait_flag pair on (PIPE_MTE3, PIPE_S, EVENT_ID0); kBarrierAll and every other value run pipe_barrier(PIPE_ALL). */ +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} +/* cube_kernel: v1 backs the three TPipe objects built below; v2, v3 and v4 are only read, through TLOAD. NOTE(review): the roles of v2/v3/v4 (presumably the attention inputs) are not stated in this file - confirm against the IR this was generated from. */ +AICORE void cube_kernel(__gm__ float* v1, __gm__ half* v2, __gm__ half* v3, __gm__ half* v4) { + unsigned v5 = 3840; + unsigned v6 = 384; + unsigned v7 = 256; + unsigned v8 = 128; + unsigned v9 = 16; + unsigned v10 = 0; + const int32_t v11 = 16; + const int32_t v12 = 256; + const int32_t v13 = 128; + const int32_t v14 = 1; + const int32_t v15 = 0; + const int32_t v16 = 524288; + const int32_t v17 = 262144; + const int64_t v18 = 0; + const int64_t v19 = 32768; + const int64_t v20 = 65536; + const int64_t v21 = 163840; + const int64_t v22 = 131072; + const int32_t v23 = 2; + const int32_t v24 = 7; + const int32_t v25 = 393216; + using T = float; + /* Everything up to the matching #endif is built only when __DAV_CUBE__ is defined. The first statements split the v11 (16) outer iterations across get_block_num() blocks: v31 is the quotient, v32 the remainder, blocks with index below v32 take v33 = v31 + 1 iterations, and v35 is this block's first iteration. */ + #if defined(__DAV_CUBE__) + size_t v26 = (size_t) v14; + int64_t v27 = get_block_num(); + int32_t v28 = (int32_t) ((int64_t) v27); + int64_t v29 = get_block_idx(); + int32_t v30 = (int32_t) ((int64_t) v29); + int32_t v31 = v11 / v28; + int32_t v32 = v11 %
v28; + int32_t v33 = (int32_t) ((uint32_t) v31 + (uint32_t) v14); + bool v34 = v30 < v32; + int32_t v35 = v34 ? (int32_t) ((uint32_t) v30 * (uint32_t) v33) : (int32_t) ((uint32_t) ((int32_t) (uint32_t) v32 * (uint32_t) v33) + (uint32_t) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v30 - (uint32_t) v32) * (uint32_t) v31)); + int32_t v36 = (int32_t) ((uint32_t) v30 * (uint32_t) v16); + int32_t v37 = (int32_t) ((uint32_t) v36 + (uint32_t) v25); + __gm__ float* v38 = v1 + v37; + + __gm__ float* v39 = v1 + v36; + /* Three pipes, all based inside this block's region of v1 that starts at float offset v36 = block index * v16 (524288): v40 at +0 and v43 at +v17 (262144) are DIR_C2V, v44 at +v25 (393216) is DIR_V2C. */ + auto v40 = TPipe<0, Direction::DIR_C2V, 131072, 8, 8, true>(v39, v15, v15); + int32_t v41 = (int32_t) ((uint32_t) v36 + (uint32_t) v17); + __gm__ float* v42 = v1 + v41; + + auto v43 = TPipe<2, Direction::DIR_C2V, 65536, 8, 8, true>(v42, v15, v15); + auto v44 = TPipe<4, Direction::DIR_V2C, 65536, 8, 8, false>(v38, v15, v25); + Tile v45; + TASSIGN(v45, v18); + Tile v46; + TASSIGN(v46, v18); + Tile v47; + TASSIGN(v47, v19); + Tile v48; + TASSIGN(v48, v20); + Tile v49; + TASSIGN(v49, v18); + Tile v50; + TASSIGN(v50, v18); + Tile v51; + TASSIGN(v51, v19); + Tile v52; + TASSIGN(v52, v21); + Tile v53; + TASSIGN(v53, v18); + Tile v54; + TASSIGN(v54, v22); /* the four set_flag triples issued next are waited on again, in the same order, right after the outer loop closes */ + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + for (size_t v55 = (size_t) v35; v55 < ((size_t) ((int32_t) (uint32_t) v35 + (uint32_t) (v34 ?
v33 : v31))); v55 += v26) { + pto::Shape<1, 1, 1, 128, 128> v56 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v57 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v58 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v2 + (v10 + (unsigned) ((int32_t) (uint32_t) ((int32_t) v55) * (uint32_t) v13) * (unsigned) v13 + v10 * (unsigned) v14), v56, v57); + TLOAD(v45, v58); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v46, v45); + pto::Shape<1, 1, 1, 128, 128> v59 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v60 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v61 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v10 + v10 * (unsigned) v14 + v10 * (unsigned) v13), v59, v60); + TLOAD(v47, v61); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TMOV(v49, v47); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + __cc__ float* v62 = v50.data(); + __cc__ float* v63 = v62 + (v10 + v10 * v9 + v10 * v8); + __cc__ float* v64 = (__cc__ float*) v63; + Tile v65; + uint64_t v66 = reinterpret_cast(v63); + TASSIGN(v65, v66); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL(v65, v46, v49); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 128, 128> v67 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v68 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v69 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v10 + v10 * (unsigned) v14 + v8 * (unsigned) v13), v67, v68); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v47, v69); + set_flag(PIPE_MTE2,
PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v49, v47); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + __cc__ float* v70 = v50.data(); + __cc__ float* v71 = v70 + (v10 + v10 * v9 + v8 * v8); + __cc__ float* v72 = (__cc__ float*) v71; + Tile v73; + uint64_t v74 = reinterpret_cast(v71); + TASSIGN(v73, v74); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + TMATMUL(v73, v46, v49); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v75(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v40, v75); + __gm__ float* v76 = PTOAS__GLOBAL_TENSOR_DATA(v75); + pto::Shape<1, 1, 1, 128, 256> v77 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<32768, 32768, 32768, 256, 1> v78 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v79 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v76, v77, v78); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v79, v50); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v40, v75); + pto::Shape<1, 1, 1, 128, 128> v80 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v81 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v82 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v10 + v10 * (unsigned) v14 + v7 * (unsigned) v13), v80, v81); + TLOAD(v48, v82); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + TMOV(v49, v48); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); +
wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + TMATMUL(v65, v46, v49); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + pto::Shape<1, 1, 1, 128, 128> v83 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v84 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v85 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v10 + v10 * (unsigned) v14 + v6 * (unsigned) v13), v83, v84); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v48, v85); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + TMOV(v49, v48); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + TMATMUL(v73, v46, v49); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v86(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v40, v86); + __gm__ float* v87 = PTOAS__GLOBAL_TENSOR_DATA(v86); + pto::Shape<1, 1, 1, 128, 256> v88 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<32768, 32768, 32768, 256, 1> v89 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v90 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v87, v88, v89); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + TSTORE(v90, v50); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v40, v86); + pto::Shape<1, 1, 1, 256, 128> v91 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<32768, 32768, 32768, 128, 1> v92 = pto::Stride<32768, 32768, 32768, 128, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 128, 1>, pto::Layout::ND> v93 =
GlobalTensor, pto::Stride<32768, 32768, 32768, 128, 1>, pto::Layout::ND>(v4 + (v10 + v10 * (unsigned) v13 + v10 * (unsigned) v14), v91, v92); + TLOAD(v52, v93); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + for (size_t v94 = (size_t) v15; v94 < ((size_t) v24); v94 += v26) { /* inner loop: v94 goes from v15 (0) up to, but not including, v24 (7) in steps of v26 (1) */ + int32_t v95 = (int32_t) ((uint32_t) ((int32_t) v94) * (uint32_t) v23); + int32_t v96 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v95 + (uint32_t) v23) * (uint32_t) v12); + Tile v97; + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v44, v97); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID5); + TMOV(v51, v97); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v44); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + TMOV(v53, v52); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + int32_t v98 = (int32_t) ((uint32_t) v95 + (uint32_t) v14); + pto::Shape<1, 1, 1, 256, 128> v99 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<32768, 32768, 32768, 128, 1> v100 = pto::Stride<32768, 32768, 32768, 128, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 128, 1>, pto::Layout::ND> v101 = GlobalTensor, pto::Stride<32768, 32768, 32768, 128, 1>, pto::Layout::ND>(v4 + (v10 + (unsigned) ((int32_t) (uint32_t) v98 * (uint32_t) v12) * (unsigned) v13 + v10 * (unsigned) v14), v99, v100); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v52, v101); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + TMATMUL(v54, v51, v53); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v102(nullptr); + TALLOC, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v102); + __gm__ float* v103 = PTOAS__GLOBAL_TENSOR_DATA(v102); + pto::Shape<1, 1, 1, 128, 128> v104 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384,
128, 1> v105 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v106 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v103, v104, v105); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + TSTORE(v106, v54); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + TPUSH, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v102); + pto::Shape<1, 1, 1, 128, 128> v107 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v108 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v109 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v10 + v10 * (unsigned) v14 + (unsigned) v96 * (unsigned) v13), v107, v108); + TLOAD(v47, v109); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TMOV(v49, v47); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + TMATMUL(v65, v46, v49); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + pto::Shape<1, 1, 1, 128, 128> v110 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v111 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v112 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v10 + v10 * (unsigned) v14 + (unsigned) ((int32_t) (uint32_t) v96 + (uint32_t) v13) * (unsigned) v13), v110, v111); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v47, v112); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TMOV(v49, v47); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + TMATMUL(v73, v46,
v49); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v113(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v40, v113); + __gm__ float* v114 = PTOAS__GLOBAL_TENSOR_DATA(v113); + pto::Shape<1, 1, 1, 128, 256> v115 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<32768, 32768, 32768, 256, 1> v116 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v117 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v114, v115, v116); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + TSTORE(v117, v50); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v40, v113); + int32_t v118 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v98 + (uint32_t) v23) * (uint32_t) v12); + Tile v119; + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v44, v119); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v51, v119); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v44); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v53, v52); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + pto::Shape<1, 1, 1, 256, 128> v120 = pto::Shape<1, 1, 1, 256, 128>(); + pto::Stride<32768, 32768, 32768, 128, 1> v121 = pto::Stride<32768, 32768, 32768, 128, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 128, 1>, pto::Layout::ND> v122 = GlobalTensor, pto::Stride<32768, 32768, 32768, 128, 1>, pto::Layout::ND>(v4 + (v10 + (unsigned) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v98 + (uint32_t) v14) * (uint32_t) v12) * (unsigned) v13 + v10 * (unsigned) v14), v120, v121); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + TLOAD(v52, v122); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); +
wait_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + TMATMUL(v54, v51, v53); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v123(nullptr); + TALLOC, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v123); + __gm__ float* v124 = PTOAS__GLOBAL_TENSOR_DATA(v123); + pto::Shape<1, 1, 1, 128, 128> v125 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v126 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v127 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v124, v125, v126); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + TSTORE(v127, v54); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + TPUSH, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v123); + pto::Shape<1, 1, 1, 128, 128> v128 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v129 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v130 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v10 + v10 * (unsigned) v14 + (unsigned) v118 * (unsigned) v13), v128, v129); + TLOAD(v48, v130); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v49, v48); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + TMATMUL(v65, v46, v49); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 128, 128> v131 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v132 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>,
pto::Layout::DN> v133 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v10 + v10 * (unsigned) v14 + (unsigned) ((int32_t) (uint32_t) v118 + (uint32_t) v13) * (unsigned) v13), v131, v132); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + TLOAD(v48, v133); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v49, v48); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL(v73, v46, v49); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v134(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v40, v134); + __gm__ float* v135 = PTOAS__GLOBAL_TENSOR_DATA(v134); + pto::Shape<1, 1, 1, 128, 256> v136 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<32768, 32768, 32768, 256, 1> v137 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v138 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v135, v136, v137); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + TSTORE(v138, v50); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v40, v134); + }; + set_flag(PIPE_FIX, PIPE_M, EVENT_ID6); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + Tile v139; + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v44, v139); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v51, v139); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v44); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v53, v52); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + pto::Shape<1, 1, 1, 256, 128> v140 = pto::Shape<1, 1, 1, 256, 128>(); +
pto::Stride<32768, 32768, 32768, 128, 1> v141 = pto::Stride<32768, 32768, 32768, 128, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 128, 1>, pto::Layout::ND> v142 = GlobalTensor, pto::Stride<32768, 32768, 32768, 128, 1>, pto::Layout::ND>(v4 + (v10 + v5 * (unsigned) v13 + v10 * (unsigned) v14), v140, v141); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + TLOAD(v52, v142); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID6); + TMATMUL(v54, v51, v53); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID6); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v143(nullptr); + TALLOC, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v143); + __gm__ float* v144 = PTOAS__GLOBAL_TENSOR_DATA(v143); + pto::Shape<1, 1, 1, 128, 128> v145 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v146 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v147 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v144, v145, v146); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID6); + TSTORE(v147, v54); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID7); + TPUSH, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v143); + Tile v148; + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v44, v148); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v51, v148); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v44); + TMOV(v53, v52); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID7); + TMATMUL(v54, v51, v53); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID7); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>,
pto::Layout::ND> v149(nullptr); + TALLOC, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v149); + __gm__ float* v150 = PTOAS__GLOBAL_TENSOR_DATA(v149); + pto::Shape<1, 1, 1, 128, 128> v151 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v152 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v153 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v150, v151, v152); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID7); + TSTORE(v153, v54); + TPUSH, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v149); + } /* closes the outer v55 loop; the four waits below use the same (pipe, pipe, event) triples as the four set_flag calls issued just before that loop */ + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +AICORE void vector_kernel(__gm__ float* v1, __gm__ float* v2) { + unsigned v3 = 128; + RoundMode v4 = RoundMode::CAST_RINT; + unsigned v5 = 256; + unsigned v6 = 0; + const int32_t v7 = 0; + const int32_t v8 = 16; + const int32_t v9 = 64; + const int32_t v10 = 128; + const int32_t v11 = 1; + const int32_t v12 = 524288; + const int32_t v13 = 262144; + const int64_t v14 = 196608; + const int64_t v15 = 262144; + const int64_t v16 = 327680; + const int64_t v17 = 360448; + const int64_t v18 = 393216; + const int64_t v19 = 393472; + const int64_t v20 = 393728; + const int64_t v21 = 393984; + const int64_t v22 = 394240; + const int64_t v23 = 394496; + const float v24 = 0.0883883461f; + const float v25 = 1.0f; + const int64_t v26 = 394752; + const int32_t v27 = 7; + const int32_t v28 = 393216; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + size_t v29 = (size_t) v11; + int64_t v30 = get_block_num(); + int32_t v31 = (int32_t)
((int64_t) v30); + int64_t v32 = get_block_idx(); + int32_t v33 = (int32_t) ((int64_t) v32); + int32_t v34 = v8 / v31; + int32_t v35 = v8 % v31; + int32_t v36 = (int32_t) ((uint32_t) v34 + (uint32_t) v11); + bool v37 = v33 < v35; + int32_t v38 = v37 ? (int32_t) ((uint32_t) v33 * (uint32_t) v36) : (int32_t) ((uint32_t) ((int32_t) (uint32_t) v35 * (uint32_t) v36) + (uint32_t) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v33 - (uint32_t) v35) * (uint32_t) v34)); + int32_t v39 = (int32_t) ((uint32_t) v33 * (uint32_t) v12); + int32_t v40 = (int32_t) ((uint32_t) v39 + (uint32_t) v28); + __gm__ float* v41 = v1 + v40; + + __gm__ float* v42 = v1 + v39; + + auto v43 = TPipe<0, Direction::DIR_C2V, 131072, 8, 8, true>(v42, v7, v7); + int32_t v44 = (int32_t) ((uint32_t) v39 + (uint32_t) v13); + __gm__ float* v45 = v1 + v44; + + auto v46 = TPipe<2, Direction::DIR_C2V, 65536, 8, 8, true>(v45, v7, v7); + auto v47 = TPipe<4, Direction::DIR_V2C, 65536, 8, 8, false>(v41, v7, v28); + int64_t v48 = get_subblockid(); + int32_t v49 = (int32_t) ((uint32_t) ((int32_t) (int64_t) v48) * (uint32_t) v9); + Tile v50; + TASSIGN(v50, v14); + Tile v51; + TASSIGN(v51, v15); + Tile v52; + TASSIGN(v52, v16); + Tile v53; + TASSIGN(v53, v17); + Tile v54; + TASSIGN(v54, v18); + Tile v55; + TASSIGN(v55, v19); + Tile v56; + TASSIGN(v56, v20); + Tile v57; + TASSIGN(v57, v21); + Tile v58; + TASSIGN(v58, v22); + Tile v59; + TASSIGN(v59, v23); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID5); + for (size_t v60 = (size_t) v38; v60 < ((size_t) ((int32_t) (uint32_t) v38 + (uint32_t) (v37 ? 
v36 : v34))); v60 += v29) { + Tile v61; + TASSIGN(v61, v26); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v62(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v62); + __gm__ float* v63 = PTOAS__GLOBAL_TENSOR_DATA(v62); + pto::Shape<1, 1, 1, 64, 256> v64 = pto::Shape<1, 1, 1, 64, 256>(); + pto::Stride<16384, 16384, 16384, 256, 1> v65 = pto::Stride<16384, 16384, 16384, 256, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND> v66 = GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND>(v63 + (v6 + (unsigned) v49 * v5), v64, v65); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v61, v66); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v61, v61, v24); + pipe_barrier(PIPE_V); + TROWMAX(v55, v61, v51); + Tile v67; + TRESHAPE(v67, v55); + Tile v68; + TRESHAPE(v68, v54); + pipe_barrier(PIPE_V); + TROWEXPANDSUB(v51, v61, v55); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TMULS(v68, v67, v25); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + pipe_barrier(PIPE_V); + TROWSUM(v56, v51, v50); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TCVT(v52, v51, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v62); + Tile v69; + TASSIGN(v69, v26); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v70(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v70); + __gm__ float* v71 = PTOAS__GLOBAL_TENSOR_DATA(v70); + pto::Shape<1, 1, 1, 64, 256> v72 = pto::Shape<1, 1, 1, 64, 256>(); + pto::Stride<16384, 16384, 16384, 256, 1> v73 = pto::Stride<16384, 
16384, 16384, 256, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND> v74 = GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND>(v71 + (v6 + (unsigned) v49 * v5), v72, v73); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v69, v74); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TMULS(v69, v69, v24); + pipe_barrier(PIPE_V); + TROWMAX(v55, v69, v51); + Tile v75; + TRESHAPE(v75, v59); + Tile v76; + TRESHAPE(v76, v56); + Tile v77; + TRESHAPE(v77, v57); + pipe_barrier(PIPE_V); + TMAX(v67, v67, v68); + pipe_barrier(PIPE_V); + TSUB(v75, v68, v67); + pipe_barrier(PIPE_V); + TMULS(v68, v67, v25); + TROWEXPANDSUB(v51, v69, v55); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TEXP(v75, v75); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + TMUL(v76, v76, v75); + pipe_barrier(PIPE_V); + TROWSUM(v57, v51, v50); + pipe_barrier(PIPE_V); + TADD(v76, v76, v77); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + TCVT(v52, v51, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v70); + Tile v78; + TASSIGN(v78, v26); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v79(nullptr); + TPOP, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v79); + __gm__ float* v80 = PTOAS__GLOBAL_TENSOR_DATA(v79); + pto::Shape<1, 1, 1, 64, 128> v81 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v82 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v83 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v80 + (v6 + (unsigned) v49 * v3), v81, v82); + wait_flag(PIPE_V, PIPE_MTE2, 
EVENT_ID2); + TLOAD(v78, v83); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + TMOV(v53, v78); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + TFREE, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v79); + Tile v84; + TASSIGN(v84, v26); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v85(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v85); + __gm__ float* v86 = PTOAS__GLOBAL_TENSOR_DATA(v85); + pto::Shape<1, 1, 1, 64, 256> v87 = pto::Shape<1, 1, 1, 64, 256>(); + pto::Stride<16384, 16384, 16384, 256, 1> v88 = pto::Stride<16384, 16384, 16384, 256, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND> v89 = GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND>(v86 + (v6 + (unsigned) v49 * v5), v87, v88); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + TLOAD(v84, v89); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + TMULS(v84, v84, v24); + pipe_barrier(PIPE_V); + TROWMAX(v55, v84, v51); + Tile v90; + TRESHAPE(v90, v58); + pipe_barrier(PIPE_V); + TMAX(v67, v67, v68); + pipe_barrier(PIPE_V); + TSUB(v90, v68, v67); + pipe_barrier(PIPE_V); + TMULS(v68, v67, v25); + TROWEXPANDSUB(v51, v84, v55); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + TEXP(v90, v90); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + TMUL(v76, v76, v90); + pipe_barrier(PIPE_V); + TROWSUM(v57, v51, v50); + pipe_barrier(PIPE_V); + TADD(v76, v76, v77); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TCVT(v52, v51, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID2); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID2); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v85); + Tile v91; + 
TASSIGN(v91, v26); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v92(nullptr); + TPOP, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v92); + __gm__ float* v93 = PTOAS__GLOBAL_TENSOR_DATA(v92); + pto::Shape<1, 1, 1, 64, 128> v94 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v95 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v96 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v93 + (v6 + (unsigned) v49 * v3), v94, v95); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + TLOAD(v91, v96); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + TROWEXPANDMUL(v53, v53, v59); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + TADD(v53, v53, v91); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID5); + TFREE, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v92); + Tile v97; + TASSIGN(v97, v26); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v98(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v98); + __gm__ float* v99 = PTOAS__GLOBAL_TENSOR_DATA(v98); + pto::Shape<1, 1, 1, 64, 256> v100 = pto::Shape<1, 1, 1, 64, 256>(); + pto::Stride<16384, 16384, 16384, 256, 1> v101 = pto::Stride<16384, 16384, 16384, 256, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND> v102 = GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND>(v99 + (v6 + (unsigned) v49 * v5), v100, v101); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID5); + TLOAD(v97, v102); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID5); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID5); + TMULS(v97, v97, v24); + pipe_barrier(PIPE_V); + TROWMAX(v55, v97, v51); + pipe_barrier(PIPE_V); + TMAX(v67, v67, v68); + pipe_barrier(PIPE_V); + TSUB(v75, 
v68, v67); + pipe_barrier(PIPE_V); + TMULS(v68, v67, v25); + TROWEXPANDSUB(v51, v97, v55); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID6); + TEXP(v75, v75); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + TMUL(v76, v76, v75); + pipe_barrier(PIPE_V); + TROWSUM(v57, v51, v50); + pipe_barrier(PIPE_V); + TADD(v76, v76, v77); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + TCVT(v52, v51, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID4); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v98); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID6); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID4); + for (size_t v103 = v29; v103 < ((size_t) v27); v103 += v29) { + Tile v104; + TASSIGN(v104, v26); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v105(nullptr); + TPOP, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v105); + __gm__ float* v106 = PTOAS__GLOBAL_TENSOR_DATA(v105); + pto::Shape<1, 1, 1, 64, 128> v107 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v108 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v109 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v106 + (v6 + (unsigned) v49 * v3), v107, v108); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + TLOAD(v104, v109); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID6); + TROWEXPANDMUL(v53, v53, v58); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID6); + TADD(v53, v53, v104); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TFREE, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v105); + Tile v110; + TASSIGN(v110, v26); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> 
v111(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v111); + __gm__ float* v112 = PTOAS__GLOBAL_TENSOR_DATA(v111); + pto::Shape<1, 1, 1, 64, 256> v113 = pto::Shape<1, 1, 1, 64, 256>(); + pto::Stride<16384, 16384, 16384, 256, 1> v114 = pto::Stride<16384, 16384, 16384, 256, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND> v115 = GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND>(v112 + (v6 + (unsigned) v49 * v5), v113, v114); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v110, v115); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v110, v110, v24); + pipe_barrier(PIPE_V); + TROWMAX(v55, v110, v51); + pipe_barrier(PIPE_V); + TMAX(v67, v67, v68); + pipe_barrier(PIPE_V); + TSUB(v90, v68, v67); + pipe_barrier(PIPE_V); + TMULS(v68, v67, v25); + TROWEXPANDSUB(v51, v110, v55); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TEXP(v90, v90); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + TMUL(v76, v76, v90); + pipe_barrier(PIPE_V); + TROWSUM(v57, v51, v50); + pipe_barrier(PIPE_V); + TADD(v76, v76, v77); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID5); + TCVT(v52, v51, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID4); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID4); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v111); + Tile v116; + TASSIGN(v116, v26); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v117(nullptr); + TPOP, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v117); + __gm__ float* v118 = PTOAS__GLOBAL_TENSOR_DATA(v117); + pto::Shape<1, 1, 1, 64, 128> v119 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v120 = pto::Stride<8192, 
8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v121 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v118 + (v6 + (unsigned) v49 * v3), v119, v120); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v116, v121); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TROWEXPANDMUL(v53, v53, v59); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v53, v53, v116); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TFREE, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v117); + Tile v122; + TASSIGN(v122, v26); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v123(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v123); + __gm__ float* v124 = PTOAS__GLOBAL_TENSOR_DATA(v123); + pto::Shape<1, 1, 1, 64, 256> v125 = pto::Shape<1, 1, 1, 64, 256>(); + pto::Stride<16384, 16384, 16384, 256, 1> v126 = pto::Stride<16384, 16384, 16384, 256, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND> v127 = GlobalTensor, pto::Stride<16384, 16384, 16384, 256, 1>, pto::Layout::ND>(v124 + (v6 + (unsigned) v49 * v5), v125, v126); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v122, v127); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v122, v122, v24); + pipe_barrier(PIPE_V); + TROWMAX(v55, v122, v51); + pipe_barrier(PIPE_V); + TMAX(v67, v67, v68); + pipe_barrier(PIPE_V); + TSUB(v75, v68, v67); + pipe_barrier(PIPE_V); + TMULS(v68, v67, v25); + TROWEXPANDSUB(v51, v122, v55); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + TEXP(v75, v75); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + TMUL(v76, v76, v75); + pipe_barrier(PIPE_V); + TROWSUM(v57, v51, v50); + pipe_barrier(PIPE_V); + TADD(v76, v76, v77); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + TCVT(v52, v51, v4); + set_flag(PIPE_V, 
PIPE_MTE3, EVENT_ID5); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID5); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID5); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v123); + }; + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v128; + TASSIGN(v128, v26); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v129(nullptr); + TPOP, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v129); + __gm__ float* v130 = PTOAS__GLOBAL_TENSOR_DATA(v129); + pto::Shape<1, 1, 1, 64, 128> v131 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v132 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v133 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v130 + (v6 + (unsigned) v49 * v3), v131, v132); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v128, v133); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TROWEXPANDMUL(v53, v53, v58); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v53, v53, v128); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TFREE, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v129); + Tile v134; + TASSIGN(v134, v26); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v135(nullptr); + TPOP, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v135); + __gm__ float* v136 = PTOAS__GLOBAL_TENSOR_DATA(v135); + pto::Shape<1, 1, 1, 64, 128> v137 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v138 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v139 = GlobalTensor, 
pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v136 + (v6 + (unsigned) v49 * v3), v137, v138); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v134, v139); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TROWEXPANDMUL(v53, v53, v59); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v53, v53, v134); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TFREE, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v46, v135); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v53, v53, v56); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID6); + pto::Shape<1, 1, 1, 64, 128> v140 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v141 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v142 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v2 + (v6 + (unsigned) ((int32_t) (uint32_t) ((int32_t) (uint32_t) ((int32_t) v60) * (uint32_t) v10) + (uint32_t) v49) * (unsigned) v10 + v6 * (unsigned) v11), v140, v141); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID6); + TSTORE(v142, v53); + } + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID5); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +__global__ AICORE void call_both(__gm__ int64_t* v1, __gm__ float* v2, __gm__ half* v3, __gm__ half* v4, __gm__ half* v5, __gm__ float* v6) { + using T = float; + uint64_t v7 = (uint64_t) v1; + set_ffts_base_addr(v7); + cube_kernel(v2, v3, v4, v5); + vector_kernel(v2, v6); + return; +} diff --git a/examples/aot/flash_attention/ir_ref/fa.pto b/examples/aot/flash_attention/ir_ref/fa.pto new file mode 100644 index 00000000..1fb6b357 --- /dev/null +++ b/examples/aot/flash_attention/ir_ref/fa.pto @@ -0,0 +1,501 @@ +// RUN: ptoas --pto-arch=a3 
--pto-level=level3 --enable-insert-sync %s >/dev/null + +module { + func.func @cube_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c128_0 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %c128_1 = arith.constant 128 : index + %c16 = arith.constant 16 : index + %c16_2 = arith.constant 16 : index + %0 = pto.get_block_num + %1 = arith.index_cast %0 : i64 to index + %2 = pto.get_block_idx + %3 = arith.index_cast %2 : i64 to index + %4 = arith.divsi %c16_2, %1 : index + %5 = arith.remsi %c16_2, %1 : index + %6 = arith.addi %4, %c1 : index + %7 = arith.muli %3, %6 : index + %8 = arith.addi %4, %c1 : index + %9 = arith.muli %5, %8 : index + %10 = arith.subi %3, %5 : index + %11 = arith.muli %10, %4 : index + %12 = arith.addi %9, %11 : index + %13 = arith.cmpi slt, %3, %5 : index + %14 = arith.select %13, %7, %12 : index + %15 = arith.cmpi slt, %3, %5 : index + %16 = arith.addi %4, %c1 : index + %17 = arith.select %15, %16, %4 : index + %18 = arith.addi %14, %17 : index + %c524288 = arith.constant 524288 : index + %19 = arith.muli %3, %c524288 : index + %20 = pto.addptr %arg0, %19 : -> + %c0_3 = arith.constant 0 : index + %21 = pto.addptr %20, %c0_3 : -> + %c262144 = arith.constant 262144 : index + %22 = pto.addptr %20, %c262144 : -> + %c393216 = arith.constant 393216 : index + %23 = pto.addptr %20, %c393216 : -> + %qk_slot_desc = pto.make_tensor_view %21, shape = [%c128, %c256], strides = [%c256, %c1] : !pto.tensor_view<128x256xf32> + pto.aic_initialize_pipe{id = 25, dir_mask = 1, slot_size = 131072} (gm_slot_tensor = %qk_slot_desc : !pto.tensor_view<128x256xf32>) + %pv_slot_desc = pto.make_tensor_view %22, shape = [%c128, %c128_0], strides = [%c128_0, %c1] : !pto.tensor_view<128x128xf32> + pto.aic_initialize_pipe{id = 27, dir_mask = 
1, slot_size = 65536} (gm_slot_tensor = %pv_slot_desc : !pto.tensor_view<128x128xf32>) + %28 = pto.reserve_buffer{name = "fa_p_v2c_fifo", size = 524288, location = , auto = false, base = 393216} -> i32 + %c0_i32 = arith.constant 0 : i32 + pto.aic_initialize_pipe{id = 30, dir_mask = 2, slot_size = 65536, nosplit = false} (gm_slot_buffer = %23 : !pto.ptr, c2v_consumer_buf = %c0_i32 : i32, v2c_consumer_buf = %28 : i32) + %c0_i64 = arith.constant 0 : i64 + %c0_i64_4 = arith.constant 0 : i64 + %29 = pto.alloc_tile addr = %c0_i64_4 : !pto.tile_buf + %c0_i64_5 = arith.constant 0 : i64 + %30 = pto.alloc_tile addr = %c0_i64_5 : !pto.tile_buf + %c32768_i64 = arith.constant 32768 : i64 + %31 = pto.alloc_tile addr = %c32768_i64 : !pto.tile_buf + %c65536_i64 = arith.constant 65536 : i64 + %32 = pto.alloc_tile addr = %c65536_i64 : !pto.tile_buf + %33 = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %c0_i64_6 = arith.constant 0 : i64 + %34 = pto.alloc_tile addr = %c0_i64_6 : !pto.tile_buf + %c98304_i64 = arith.constant 98304 : i64 + %35 = pto.alloc_tile addr = %c98304_i64 : !pto.tile_buf + %c32768_i64_7 = arith.constant 32768 : i64 + %36 = pto.alloc_tile addr = %c32768_i64_7 : !pto.tile_buf + %c163840_i64 = arith.constant 163840 : i64 + %37 = pto.alloc_tile addr = %c163840_i64 : !pto.tile_buf + %38 = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %c131072_i64 = arith.constant 131072 : i64 + %39 = pto.alloc_tile addr = %c131072_i64 : !pto.tile_buf + %c2048 = arith.constant 2048 : index + %40 = pto.make_tensor_view %arg1, shape = [%c2048, %c128_0], strides = [%c128_0, %c1] : !pto.tensor_view + %41 = pto.make_tensor_view %arg2, shape = [%c128_0, %c4096], strides = [%c1, %c128_0] : !pto.tensor_view + %42 = pto.make_tensor_view %arg3, shape = [%c4096, %c128_0], strides = [%c128_0, %c1] : !pto.tensor_view + scf.for %arg4 = %14 to %18 step %c1 { + %43 = arith.muli %arg4, %c128 : index + %44 = pto.partition_view %40, offsets = [%43, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view -> 
!pto.partition_tensor_view<128x128xf16> + pto.tload ins(%44 : !pto.partition_tensor_view<128x128xf16>) outs(%29 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%30 : !pto.tile_buf) + %c0_8 = arith.constant 0 : index + %c0_9 = arith.constant 0 : index + %45 = arith.addi %c0_8, %c0_9 : index + %46 = pto.partition_view %41, offsets = [%c0, %45], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%46 : !pto.partition_tensor_view<128x128xf16>) outs(%31 : !pto.tile_buf) + pto.tmov ins(%31 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c0_10 = arith.constant 0 : index + %47 = pto.subview %34[%c0, %c0_10] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%47 : !pto.tile_buf) + %c128_11 = arith.constant 128 : index + %48 = arith.addi %c0_8, %c128_11 : index + %49 = pto.partition_view %41, offsets = [%c0, %48], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%49 : !pto.partition_tensor_view<128x128xf16>) outs(%31 : !pto.tile_buf) + pto.tmov ins(%31 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c128_12 = arith.constant 128 : index + %50 = pto.subview %34[%c0, %c128_12] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%50 : !pto.tile_buf) + %qk_push_0 = pto.talloc_to_aiv {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_push_part_0 = pto.partition_view %qk_push_0, offsets = [%c0, %c0], sizes = [%c128, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<128x256xf32> + pto.tstore ins(%34 : !pto.tile_buf) outs(%qk_push_part_0 : !pto.partition_tensor_view<128x256xf32>) + pto.tpush_to_aiv(%qk_push_0 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + %c256_13 = arith.constant 256 : index + %c0_14 = arith.constant 0 : index + %51 = arith.addi %c256_13, %c0_14 : index + %52 = pto.partition_view %41, offsets = 
[%c0, %51], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%52 : !pto.partition_tensor_view<128x128xf16>) outs(%32 : !pto.tile_buf) + pto.tmov ins(%32 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c0_15 = arith.constant 0 : index + %53 = pto.subview %34[%c0, %c0_15] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%53 : !pto.tile_buf) + %c128_16 = arith.constant 128 : index + %54 = arith.addi %c256_13, %c128_16 : index + %55 = pto.partition_view %41, offsets = [%c0, %54], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%55 : !pto.partition_tensor_view<128x128xf16>) outs(%32 : !pto.tile_buf) + pto.tmov ins(%32 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c128_17 = arith.constant 128 : index + %56 = pto.subview %34[%c0, %c128_17] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + %qk_push_1 = pto.talloc_to_aiv {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_push_part_1 = pto.partition_view %qk_push_1, offsets = [%c0, %c0], sizes = [%c128, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<128x256xf32> + pto.tstore ins(%34 : !pto.tile_buf) outs(%qk_push_part_1 : !pto.partition_tensor_view<128x256xf32>) + pto.tpush_to_aiv(%qk_push_1 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + %57 = pto.partition_view %42, offsets = [%c0, %c0], sizes = [%c256, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.tload ins(%57 : !pto.partition_tensor_view<256x128xf16>) outs(%37 : !pto.tile_buf) + %c2 = arith.constant 2 : index + %c7 = arith.constant 7 : index + scf.for %arg5 = %c0 to %c7 step %c1 { + %61 = arith.muli %arg5, %c2 : index + %c2_18 = arith.constant 2 : index + %62 = arith.addi %61, %c2_18 : index + %63 = arith.muli %62, %c256 : index + %64 = 
pto.tpop_from_aiv {id = 30, split = 1} -> !pto.tile_buf + pto.tmov ins(%64 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + pto.tfree_from_aiv{id = 30, split = 1} + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + %65 = arith.addi %61, %c1 : index + %66 = arith.muli %65, %c256 : index + %67 = pto.partition_view %42, offsets = [%66, %c0], sizes = [%c256, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.tload ins(%67 : !pto.partition_tensor_view<256x128xf16>) outs(%37 : !pto.tile_buf) + pto.tmatmul ins(%36, %38 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %pv_push_0 = pto.talloc_to_aiv {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_push_part_0 = pto.partition_view %pv_push_0, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<128x128xf32> + pto.tstore ins(%39 : !pto.tile_buf) outs(%pv_push_part_0 : !pto.partition_tensor_view<128x128xf32>) + pto.tpush_to_aiv(%pv_push_0 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %c0_19 = arith.constant 0 : index + %68 = arith.addi %63, %c0_19 : index + %69 = pto.partition_view %41, offsets = [%c0, %68], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%69 : !pto.partition_tensor_view<128x128xf16>) outs(%31 : !pto.tile_buf) + pto.tmov ins(%31 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c0_20 = arith.constant 0 : index + %70 = pto.subview %34[%c0, %c0_20] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%70 : !pto.tile_buf) + %c128_21 = arith.constant 128 : index + %71 = arith.addi %63, %c128_21 : index + %72 = pto.partition_view %41, offsets = [%c0, %71], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%72 : !pto.partition_tensor_view<128x128xf16>) outs(%31 : !pto.tile_buf) + pto.tmov ins(%31 : !pto.tile_buf) outs(%33 : 
!pto.tile_buf) + %c128_22 = arith.constant 128 : index + %73 = pto.subview %34[%c0, %c128_22] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%73 : !pto.tile_buf) + %qk_push_2 = pto.talloc_to_aiv {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_push_part_2 = pto.partition_view %qk_push_2, offsets = [%c0, %c0], sizes = [%c128, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<128x256xf32> + pto.tstore ins(%34 : !pto.tile_buf) outs(%qk_push_part_2 : !pto.partition_tensor_view<128x256xf32>) + pto.tpush_to_aiv(%qk_push_2 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + %74 = arith.muli %arg5, %c2 : index + %75 = arith.addi %74, %c1 : index + %c2_23 = arith.constant 2 : index + %76 = arith.addi %75, %c2_23 : index + %77 = arith.muli %76, %c256 : index + %78 = pto.tpop_from_aiv {id = 30, split = 1} -> !pto.tile_buf + pto.tmov ins(%78 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + pto.tfree_from_aiv{id = 30, split = 1} + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + %79 = arith.addi %75, %c1 : index + %80 = arith.muli %79, %c256 : index + %81 = pto.partition_view %42, offsets = [%80, %c0], sizes = [%c256, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.tload ins(%81 : !pto.partition_tensor_view<256x128xf16>) outs(%37 : !pto.tile_buf) + pto.tmatmul ins(%36, %38 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %pv_push_1 = pto.talloc_to_aiv {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_push_part_1 = pto.partition_view %pv_push_1, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<128x128xf32> + pto.tstore ins(%39 : !pto.tile_buf) outs(%pv_push_part_1 : !pto.partition_tensor_view<128x128xf32>) + pto.tpush_to_aiv(%pv_push_1 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %c0_24 = arith.constant 0 : index + %82 = arith.addi %77, %c0_24 : index 
+ %83 = pto.partition_view %41, offsets = [%c0, %82], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%83 : !pto.partition_tensor_view<128x128xf16>) outs(%32 : !pto.tile_buf) + pto.tmov ins(%32 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c0_25 = arith.constant 0 : index + %84 = pto.subview %34[%c0, %c0_25] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%84 : !pto.tile_buf) + %c128_26 = arith.constant 128 : index + %85 = arith.addi %77, %c128_26 : index + %86 = pto.partition_view %41, offsets = [%c0, %85], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%86 : !pto.partition_tensor_view<128x128xf16>) outs(%32 : !pto.tile_buf) + pto.tmov ins(%32 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c128_27 = arith.constant 128 : index + %87 = pto.subview %34[%c0, %c128_27] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%87 : !pto.tile_buf) + %qk_push_3 = pto.talloc_to_aiv {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_push_part_3 = pto.partition_view %qk_push_3, offsets = [%c0, %c0], sizes = [%c128, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<128x256xf32> + pto.tstore ins(%34 : !pto.tile_buf) outs(%qk_push_part_3 : !pto.partition_tensor_view<128x256xf32>) + pto.tpush_to_aiv(%qk_push_3 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + } + %58 = pto.tpop_from_aiv {id = 30, split = 1} -> !pto.tile_buf + pto.tmov ins(%58 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + pto.tfree_from_aiv{id = 30, split = 1} + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + %c3840 = arith.constant 3840 : index + %59 = pto.partition_view %42, offsets = [%c3840, %c0], sizes = [%c256, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<256x128xf16> + pto.tload ins(%59 : 
!pto.partition_tensor_view<256x128xf16>) outs(%37 : !pto.tile_buf) + pto.tmatmul ins(%36, %38 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %pv_push_2 = pto.talloc_to_aiv {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_push_part_2 = pto.partition_view %pv_push_2, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<128x128xf32> + pto.tstore ins(%39 : !pto.tile_buf) outs(%pv_push_part_2 : !pto.partition_tensor_view<128x128xf32>) + pto.tpush_to_aiv(%pv_push_2 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %60 = pto.tpop_from_aiv {id = 30, split = 1} -> !pto.tile_buf + pto.tmov ins(%60 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + pto.tfree_from_aiv{id = 30, split = 1} + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tmatmul ins(%36, %38 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %pv_push_3 = pto.talloc_to_aiv {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_push_part_3 = pto.partition_view %pv_push_3, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<128x128xf32> + pto.tstore ins(%39 : !pto.tile_buf) outs(%pv_push_part_3 : !pto.partition_tensor_view<128x128xf32>) + pto.tpush_to_aiv(%pv_push_3 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + } + return + } + func.func @vector_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c256 = arith.constant 256 : index + %c128_0 = arith.constant 128 : index + %c16 = arith.constant 16 : index + %c16_1 = arith.constant 16 : index + %0 = pto.get_block_num + %1 = arith.index_cast %0 : i64 to index + %2 = pto.get_block_idx + %3 = arith.index_cast %2 : i64 to index + %4 = arith.divsi %c16_1, %1 : index + %5 = arith.remsi %c16_1, %1 : index + %6 = arith.addi 
%4, %c1 : index + %7 = arith.muli %3, %6 : index + %8 = arith.addi %4, %c1 : index + %9 = arith.muli %5, %8 : index + %10 = arith.subi %3, %5 : index + %11 = arith.muli %10, %4 : index + %12 = arith.addi %9, %11 : index + %13 = arith.cmpi slt, %3, %5 : index + %14 = arith.select %13, %7, %12 : index + %15 = arith.cmpi slt, %3, %5 : index + %16 = arith.addi %4, %c1 : index + %17 = arith.select %15, %16, %4 : index + %18 = arith.addi %14, %17 : index + %c524288 = arith.constant 524288 : index + %19 = arith.muli %3, %c524288 : index + %20 = pto.addptr %arg0, %19 : -> + %c0_2 = arith.constant 0 : index + %21 = pto.addptr %20, %c0_2 : -> + %c262144 = arith.constant 262144 : index + %22 = pto.addptr %20, %c262144 : -> + %c393216 = arith.constant 393216 : index + %23 = pto.addptr %20, %c393216 : -> + %qk_slot_desc = pto.make_tensor_view %21, shape = [%c128, %c256], strides = [%c256, %c1] : !pto.tensor_view<128x256xf32> + pto.aiv_initialize_pipe{id = 25, dir_mask = 1, slot_size = 131072} (gm_slot_tensor = %qk_slot_desc : !pto.tensor_view<128x256xf32>) + %pv_slot_desc = pto.make_tensor_view %22, shape = [%c128, %c128_0], strides = [%c128_0, %c1] : !pto.tensor_view<128x128xf32> + pto.aiv_initialize_pipe{id = 27, dir_mask = 1, slot_size = 65536} (gm_slot_tensor = %pv_slot_desc : !pto.tensor_view<128x128xf32>) + %28 = pto.import_reserved_buffer{name = "fa_p_v2c_fifo", peer_func = @cube_kernel} -> i32 + %c0_i32 = arith.constant 0 : i32 + pto.aiv_initialize_pipe{id = 30, dir_mask = 2, slot_size = 65536, nosplit = false} (gm_slot_buffer = %23 : !pto.ptr, c2v_consumer_buf = %c0_i32 : i32, v2c_consumer_buf = %28 : i32) + %29 = pto.get_subblock_idx + %30 = arith.index_cast %29 : i64 to index + %31 = arith.muli %30, %c64 : index + %c196608_i64 = arith.constant 196608 : i64 + %32 = pto.alloc_tile addr = %c196608_i64 : !pto.tile_buf + %c262144_i64 = arith.constant 262144 : i64 + %33 = pto.alloc_tile addr = %c262144_i64 : !pto.tile_buf + %c327680_i64 = arith.constant 327680 : i64 + %34 
= pto.alloc_tile addr = %c327680_i64 : !pto.tile_buf + %c360448_i64 = arith.constant 360448 : i64 + %35 = pto.alloc_tile addr = %c360448_i64 : !pto.tile_buf + %c393216_i64 = arith.constant 393216 : i64 + %36 = pto.alloc_tile addr = %c393216_i64 : !pto.tile_buf + %c393472_i64 = arith.constant 393472 : i64 + %37 = pto.alloc_tile addr = %c393472_i64 : !pto.tile_buf + %c393728_i64 = arith.constant 393728 : i64 + %38 = pto.alloc_tile addr = %c393728_i64 : !pto.tile_buf + %c393984_i64 = arith.constant 393984 : i64 + %39 = pto.alloc_tile addr = %c393984_i64 : !pto.tile_buf + %c394240_i64 = arith.constant 394240 : i64 + %40 = pto.alloc_tile addr = %c394240_i64 : !pto.tile_buf + %c394496_i64 = arith.constant 394496 : i64 + %41 = pto.alloc_tile addr = %c394496_i64 : !pto.tile_buf + %cst = arith.constant 0.0883883461 : f32 + %cst_3 = arith.constant 1.000000e+00 : f32 + %c2048 = arith.constant 2048 : index + %42 = pto.make_tensor_view %arg1, shape = [%c2048, %c128_0], strides = [%c128_0, %c1] : !pto.tensor_view + scf.for %arg2 = %14 to %18 step %c1 { + %43 = arith.muli %arg2, %c128 : index + %c394752_i64 = arith.constant 394752 : i64 + %44 = pto.alloc_tile addr = %c394752_i64 : !pto.tile_buf + %qk_pop_0 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_0 = pto.partition_view %qk_pop_0, offsets = [%31, %c0], sizes = [%c64, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<64x256xf32> + pto.tload ins(%qk_pop_part_0 : !pto.partition_tensor_view<64x256xf32>) outs(%44 : !pto.tile_buf) + pto.tmuls ins(%44, %cst : !pto.tile_buf, f32) outs(%44 : !pto.tile_buf) + pto.trowmax ins(%44, %33 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %45 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %46 = pto.treshape %36 : !pto.tile_buf -> !pto.tile_buf + %47 = pto.treshape %40 : !pto.tile_buf -> !pto.tile_buf + %48 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %49 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf 
+ pto.trowexpandsub ins(%44, %37 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmuls ins(%45, %cst_3 : !pto.tile_buf, f32) outs(%46 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree_from_aic(%qk_pop_0 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + %c394752_i64_4 = arith.constant 394752 : i64 + %50 = pto.alloc_tile addr = %c394752_i64_4 : !pto.tile_buf + %qk_pop_1 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_1 = pto.partition_view %qk_pop_1, offsets = [%31, %c0], sizes = [%c64, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<64x256xf32> + pto.tload ins(%qk_pop_part_1 : !pto.partition_tensor_view<64x256xf32>) outs(%50 : !pto.tile_buf) + pto.tmuls ins(%50, %cst : !pto.tile_buf, f32) outs(%50 : !pto.tile_buf) + pto.trowmax ins(%50, %33 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %51 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %52 = pto.treshape %36 : !pto.tile_buf -> !pto.tile_buf + %53 = pto.treshape %41 : !pto.tile_buf -> !pto.tile_buf + %54 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %55 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%51, %52 : !pto.tile_buf, !pto.tile_buf) outs(%51 : !pto.tile_buf) + pto.tsub ins(%52, %51 : !pto.tile_buf, !pto.tile_buf) outs(%53 : !pto.tile_buf) + pto.tmuls ins(%51, %cst_3 : !pto.tile_buf, f32) outs(%52 : !pto.tile_buf) + pto.trowexpandsub ins(%50, %37 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%53 : !pto.tile_buf) outs(%53 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%54, %53 : !pto.tile_buf, !pto.tile_buf) outs(%54 : !pto.tile_buf) + pto.trowsum 
ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.tadd ins(%54, %55 : !pto.tile_buf, !pto.tile_buf) outs(%54 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree_from_aic(%qk_pop_1 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + %c394752_i64_5 = arith.constant 394752 : i64 + %56 = pto.alloc_tile addr = %c394752_i64_5 : !pto.tile_buf + %pv_pop_0 = pto.tpop_from_aic {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_pop_part_0 = pto.partition_view %pv_pop_0, offsets = [%31, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_0 : !pto.partition_tensor_view<64x128xf32>) outs(%56 : !pto.tile_buf) + pto.tmov ins(%56 : !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop_0 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %c394752_i64_6 = arith.constant 394752 : i64 + %57 = pto.alloc_tile addr = %c394752_i64_6 : !pto.tile_buf + %qk_pop_2 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_2 = pto.partition_view %qk_pop_2, offsets = [%31, %c0], sizes = [%c64, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<64x256xf32> + pto.tload ins(%qk_pop_part_2 : !pto.partition_tensor_view<64x256xf32>) outs(%57 : !pto.tile_buf) + pto.tmuls ins(%57, %cst : !pto.tile_buf, f32) outs(%57 : !pto.tile_buf) + pto.trowmax ins(%57, %33 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %58 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %59 = pto.treshape %36 : !pto.tile_buf -> !pto.tile_buf + %60 = pto.treshape %40 : !pto.tile_buf -> !pto.tile_buf + %61 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %62 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%58, %59 : !pto.tile_buf, !pto.tile_buf) outs(%58 : !pto.tile_buf) + pto.tsub ins(%59, %58 
: !pto.tile_buf, !pto.tile_buf) outs(%60 : !pto.tile_buf) + pto.tmuls ins(%58, %cst_3 : !pto.tile_buf, f32) outs(%59 : !pto.tile_buf) + pto.trowexpandsub ins(%57, %37 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%60 : !pto.tile_buf) outs(%60 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%61, %60 : !pto.tile_buf, !pto.tile_buf) outs(%61 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.tadd ins(%61, %62 : !pto.tile_buf, !pto.tile_buf) outs(%61 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree_from_aic(%qk_pop_2 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + %c394752_i64_7 = arith.constant 394752 : i64 + %63 = pto.alloc_tile addr = %c394752_i64_7 : !pto.tile_buf + %pv_pop_1 = pto.tpop_from_aic {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_pop_part_1 = pto.partition_view %pv_pop_1, offsets = [%31, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_1 : !pto.partition_tensor_view<64x128xf32>) outs(%63 : !pto.tile_buf) + pto.trowexpandmul ins(%35, %41 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %63 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop_1 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %c394752_i64_8 = arith.constant 394752 : i64 + %64 = pto.alloc_tile addr = %c394752_i64_8 : !pto.tile_buf + %qk_pop_3 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_3 = pto.partition_view %qk_pop_3, offsets = [%31, %c0], sizes = [%c64, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<64x256xf32> + pto.tload ins(%qk_pop_part_3 : !pto.partition_tensor_view<64x256xf32>) outs(%64 : 
!pto.tile_buf) + pto.tmuls ins(%64, %cst : !pto.tile_buf, f32) outs(%64 : !pto.tile_buf) + pto.trowmax ins(%64, %33 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %65 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %66 = pto.treshape %36 : !pto.tile_buf -> !pto.tile_buf + %67 = pto.treshape %41 : !pto.tile_buf -> !pto.tile_buf + %68 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %69 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%65, %66 : !pto.tile_buf, !pto.tile_buf) outs(%65 : !pto.tile_buf) + pto.tsub ins(%66, %65 : !pto.tile_buf, !pto.tile_buf) outs(%67 : !pto.tile_buf) + pto.tmuls ins(%65, %cst_3 : !pto.tile_buf, f32) outs(%66 : !pto.tile_buf) + pto.trowexpandsub ins(%64, %37 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%67 : !pto.tile_buf) outs(%67 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%68, %67 : !pto.tile_buf, !pto.tile_buf) outs(%68 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.tadd ins(%68, %69 : !pto.tile_buf, !pto.tile_buf) outs(%68 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree_from_aic(%qk_pop_3 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + %c7 = arith.constant 7 : index + scf.for %arg3 = %c1 to %c7 step %c1 { + %c394752_i64_11 = arith.constant 394752 : i64 + %74 = pto.alloc_tile addr = %c394752_i64_11 : !pto.tile_buf + %pv_pop_2 = pto.tpop_from_aic {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_pop_part_2 = pto.partition_view %pv_pop_2, offsets = [%31, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_2 : !pto.partition_tensor_view<64x128xf32>) outs(%74 : !pto.tile_buf) + pto.trowexpandmul ins(%35, %40 : !pto.tile_buf, !pto.tile_buf) 
outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %74 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop_2 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %c394752_i64_12 = arith.constant 394752 : i64 + %75 = pto.alloc_tile addr = %c394752_i64_12 : !pto.tile_buf + %qk_pop_4 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_4 = pto.partition_view %qk_pop_4, offsets = [%31, %c0], sizes = [%c64, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<64x256xf32> + pto.tload ins(%qk_pop_part_4 : !pto.partition_tensor_view<64x256xf32>) outs(%75 : !pto.tile_buf) + pto.tmuls ins(%75, %cst : !pto.tile_buf, f32) outs(%75 : !pto.tile_buf) + pto.trowmax ins(%75, %33 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %76 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %77 = pto.treshape %36 : !pto.tile_buf -> !pto.tile_buf + %78 = pto.treshape %40 : !pto.tile_buf -> !pto.tile_buf + %79 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %80 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%76, %77 : !pto.tile_buf, !pto.tile_buf) outs(%76 : !pto.tile_buf) + pto.tsub ins(%77, %76 : !pto.tile_buf, !pto.tile_buf) outs(%78 : !pto.tile_buf) + pto.tmuls ins(%76, %cst_3 : !pto.tile_buf, f32) outs(%77 : !pto.tile_buf) + pto.trowexpandsub ins(%75, %37 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%78 : !pto.tile_buf) outs(%78 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%79, %78 : !pto.tile_buf, !pto.tile_buf) outs(%79 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.tadd ins(%79, %80 : !pto.tile_buf, !pto.tile_buf) outs(%79 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree_from_aic(%qk_pop_4 : 
!pto.tensor_view<128x256xf32>) {id = 25, split = 0} + %c394752_i64_13 = arith.constant 394752 : i64 + %81 = pto.alloc_tile addr = %c394752_i64_13 : !pto.tile_buf + %pv_pop_3 = pto.tpop_from_aic {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_pop_part_3 = pto.partition_view %pv_pop_3, offsets = [%31, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_3 : !pto.partition_tensor_view<64x128xf32>) outs(%81 : !pto.tile_buf) + pto.trowexpandmul ins(%35, %41 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %81 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop_3 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %c394752_i64_14 = arith.constant 394752 : i64 + %82 = pto.alloc_tile addr = %c394752_i64_14 : !pto.tile_buf + %qk_pop_5 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_5 = pto.partition_view %qk_pop_5, offsets = [%31, %c0], sizes = [%c64, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<64x256xf32> + pto.tload ins(%qk_pop_part_5 : !pto.partition_tensor_view<64x256xf32>) outs(%82 : !pto.tile_buf) + pto.tmuls ins(%82, %cst : !pto.tile_buf, f32) outs(%82 : !pto.tile_buf) + pto.trowmax ins(%82, %33 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %83 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %84 = pto.treshape %36 : !pto.tile_buf -> !pto.tile_buf + %85 = pto.treshape %41 : !pto.tile_buf -> !pto.tile_buf + %86 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %87 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%83, %84 : !pto.tile_buf, !pto.tile_buf) outs(%83 : !pto.tile_buf) + pto.tsub ins(%84, %83 : !pto.tile_buf, !pto.tile_buf) outs(%85 : !pto.tile_buf) + pto.tmuls ins(%83, %cst_3 : !pto.tile_buf, f32) outs(%84 : !pto.tile_buf) + pto.trowexpandsub ins(%82, %37 : !pto.tile_buf, !pto.tile_buf) outs(%33 
: !pto.tile_buf) + pto.texp ins(%85 : !pto.tile_buf) outs(%85 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%86, %85 : !pto.tile_buf, !pto.tile_buf) outs(%86 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.tadd ins(%86, %87 : !pto.tile_buf, !pto.tile_buf) outs(%86 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree_from_aic(%qk_pop_5 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + } + %c394752_i64_9 = arith.constant 394752 : i64 + %70 = pto.alloc_tile addr = %c394752_i64_9 : !pto.tile_buf + %pv_pop_4 = pto.tpop_from_aic {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_pop_part_4 = pto.partition_view %pv_pop_4, offsets = [%31, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_4 : !pto.partition_tensor_view<64x128xf32>) outs(%70 : !pto.tile_buf) + pto.trowexpandmul ins(%35, %40 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %70 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop_4 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %c394752_i64_10 = arith.constant 394752 : i64 + %71 = pto.alloc_tile addr = %c394752_i64_10 : !pto.tile_buf + %pv_pop_5 = pto.tpop_from_aic {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_pop_part_5 = pto.partition_view %pv_pop_5, offsets = [%31, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_5 : !pto.partition_tensor_view<64x128xf32>) outs(%71 : !pto.tile_buf) + pto.trowexpandmul ins(%35, %41 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %71 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + 
pto.tfree_from_aic(%pv_pop_5 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + pto.trowexpanddiv ins(%35, %38 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + %72 = arith.addi %43, %31 : index + %73 = pto.partition_view %42, offsets = [%72, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<64x128xf32> + pto.tstore ins(%35 : !pto.tile_buf) outs(%73 : !pto.partition_tensor_view<64x128xf32>) + } + return + } + func.func @call_both(%arg0: memref<256xi64>, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr, %arg4: !pto.ptr, %arg5: !pto.ptr) attributes {pto.entry} { + pto.set_ffts %arg0 : memref<256xi64> + call @cube_kernel(%arg1, %arg2, %arg3, %arg4) : (!pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr) -> () + call @vector_kernel(%arg1, %arg5) : (!pto.ptr, !pto.ptr) -> () + return + } +} \ No newline at end of file diff --git a/examples/aot/flash_attention/ir_ref/fa_perf.cpp b/examples/aot/flash_attention/ir_ref/fa_perf.cpp new file mode 100644 index 00000000..b67d8ac7 --- /dev/null +++ b/examples/aot/flash_attention/ir_ref/fa_perf.cpp @@ -0,0 +1,1452 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +template +static AICORE inline auto PTOAS__GLOBAL_TENSOR_DATA(Tensor &tensor) + -> decltype(tensor.data()) { + return tensor.data(); +} + + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +AICORE void cube_kernel(__gm__ float* v1, __gm__ half* v2, __gm__ half* v3, __gm__ half* v4, __gm__ half* v5, __gm__ float* v6) { + unsigned v7 = 3968; + unsigned v8 = 3840; + unsigned v9 = 3712; + unsigned v10 = 3584; + unsigned v11 = 
384; + unsigned v12 = 256; + unsigned v13 = 128; + unsigned v14 = 16; + unsigned v15 = 0; + const int32_t v16 = 16; + const int32_t v17 = 256; + const int32_t v18 = 128; + const int32_t v19 = 1; + const int32_t v20 = 0; + const int32_t v21 = 262144; + const int32_t v22 = 131072; + const int64_t v23 = 0; + const int64_t v24 = 32768; + const int64_t v25 = 65536; + const int64_t v26 = 98304; + const int64_t v27 = 229376; + const int32_t v28 = 2; + const int32_t v29 = 7; + using T = float; + + #if defined(__DAV_CUBE__) + size_t v30 = (size_t) v19; + int64_t v31 = get_block_num(); + int32_t v32 = (int32_t) ((int64_t) v31); + int64_t v33 = get_block_idx(); + int32_t v34 = (int32_t) ((int64_t) v33); + int32_t v35 = v16 / v32; + int32_t v36 = v16 % v32; + int32_t v37 = (int32_t) ((uint32_t) v35 + (uint32_t) v19); + bool v38 = v34 < v36; + int32_t v39 = v38 ? (int32_t) ((uint32_t) v34 * (uint32_t) v37) : (int32_t) ((uint32_t) ((int32_t) (uint32_t) v36 * (uint32_t) v37) + (uint32_t) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v34 - (uint32_t) v36) * (uint32_t) v35)); + int32_t v40 = (int32_t) ((uint32_t) v34 * (uint32_t) v21); + int32_t v41 = (int32_t) ((uint32_t) v34 * (uint32_t) v22); + __gm__ float* v42 = v1 + v40; + + auto v43 = TPipe<0, Direction::DIR_C2V, 131072, 8, 8, true>(v42, v20, v20); + __gm__ half* v44 = v5 + v40; + + auto v45 = TPipe<2, Direction::DIR_V2C, 65536, 8, 8, true>(v44, v20, v20); + __gm__ float* v46 = v6 + v41; + + auto v47 = TPipe<4, Direction::DIR_C2V, 65536, 8, 8, true>(v46, v20, v20); + Tile v48; + TASSIGN(v48, v23); + Tile v49; + TASSIGN(v49, v23); + Tile v50; + TASSIGN(v50, v24); + Tile v51; + TASSIGN(v51, v25); + Tile v52; + TASSIGN(v52, v23); + Tile v53; + TASSIGN(v53, v23); + Tile v54; + TASSIGN(v54, v26); + Tile v55; + TASSIGN(v55, v24); + Tile v56; + TASSIGN(v56, v27); + Tile v57; + TASSIGN(v57, v23); + Tile v58; + TASSIGN(v58, v25); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + set_flag(PIPE_FIX, 
PIPE_M, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + for (size_t v59 = (size_t) v39; v59 < ((size_t) ((int32_t) (uint32_t) v39 + (uint32_t) (v38 ? v37 : v35))); v59 += v30) { + pto::Shape<1, 1, 1, 128, 128> v60 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v61 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v62 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v2 + (v15 + (unsigned) ((int32_t) (uint32_t) ((int32_t) v59) * (uint32_t) v18) * (unsigned) v18 + v15 * (unsigned) v19), v60, v61); + TLOAD(v48, v62); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v49, v48); + pto::Shape<1, 1, 1, 128, 128> v63 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v64 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v65 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v15 + v15 * (unsigned) v19 + v15 * (unsigned) v18), v63, v64); + TLOAD(v50, v65); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TMOV(v52, v50); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + __cc__ float* v66 = v53.data(); + __cc__ float* v67 = v66 + (v15 + v15 * v14 + v15 * v13); + __cc__ float* v68 = (__cc__ float*) v67; + Tile v69; + uint64_t v70 = reinterpret_cast(v67); + TASSIGN(v69, v70); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL(v69, v49, v52); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 128, 128> v71 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v72 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v73 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, 
pto::Layout::DN>(v3 + (v15 + v15 * (unsigned) v19 + v13 * (unsigned) v18), v71, v72); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v50, v73); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v52, v50); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + __cc__ float* v74 = v53.data(); + __cc__ float* v75 = v74 + (v15 + v15 * v14 + v13 * v13); + __cc__ float* v76 = (__cc__ float*) v75; + Tile v77; + uint64_t v78 = reinterpret_cast(v75); + TASSIGN(v77, v78); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + TMATMUL(v77, v49, v52); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v79(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v79); + __gm__ float* v80 = PTOAS__GLOBAL_TENSOR_DATA(v79); + pto::Shape<1, 1, 1, 128, 256> v81 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<32768, 32768, 32768, 256, 1> v82 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v83 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v80, v81, v82); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TSTORE(v83, v53); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v79); + pto::Shape<1, 1, 1, 128, 128> v84 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v85 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v86 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v15 + v15 * (unsigned) v19 + v12 * (unsigned) v18), v84, v85); + TLOAD(v51, v86); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_MTE2, PIPE_MTE1, 
EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + TMOV(v52, v51); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + TMATMUL(v69, v49, v52); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + pto::Shape<1, 1, 1, 128, 128> v87 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v88 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v89 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v15 + v15 * (unsigned) v19 + v11 * (unsigned) v18), v87, v88); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v51, v89); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + TMOV(v52, v51); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + TMATMUL(v77, v49, v52); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v90(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v90); + __gm__ float* v91 = PTOAS__GLOBAL_TENSOR_DATA(v90); + pto::Shape<1, 1, 1, 128, 256> v92 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<32768, 32768, 32768, 256, 1> v93 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v94 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v91, v92, v93); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + TSTORE(v94, v53); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v90); + pto::Shape<1, 1, 1, 128, 128> v95 = pto::Shape<1, 1, 1, 128, 128>(); + 
pto::Stride<16384, 16384, 16384, 128, 1> v96 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v97 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v4 + (v15 + v15 * (unsigned) v18 + v15 * (unsigned) v19), v95, v96); + TLOAD(v56, v97); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + for (size_t v98 = (size_t) v20; v98 < ((size_t) v29); v98 += v30) { + int32_t v99 = (int32_t) ((uint32_t) ((int32_t) v98) * (uint32_t) v28); + int32_t v100 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v99 + (uint32_t) v28) * (uint32_t) v17); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v101(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v45, v101); + __gm__ half* v102 = PTOAS__GLOBAL_TENSOR_DATA(v101); + pto::Shape<1, 1, 1, 128, 128> v103 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<32768, 32768, 32768, 256, 1> v104 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v105 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v102, v103, v104); + TLOAD(v54, v105); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID5); + TMOV(v55, v54); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + int32_t v106 = (int32_t) ((uint32_t) v99 * (uint32_t) v17); + pto::Shape<1, 1, 1, 128, 128> v107 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v108 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v109 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v4 + (v15 + (unsigned) v106 * (unsigned) v18 + v15 * (unsigned) v19), v107, v108); + pipe_barrier(PIPE_MTE2); + TLOAD(v56, v109); + set_flag(PIPE_MTE2, PIPE_MTE1, 
EVENT_ID6); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + TMOV(v57, v56); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + TMATMUL(v58, v55, v57); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + __gm__ half* v110 = PTOAS__GLOBAL_TENSOR_DATA(v101); + pto::Shape<1, 1, 1, 128, 128> v111 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<32768, 32768, 32768, 256, 1> v112 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v113 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v110 + v13, v111, v112); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v54, v113); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID7); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TMOV(v55, v54); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v45, v101); + pto::Shape<1, 1, 1, 128, 128> v114 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v115 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v116 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v4 + (v15 + (unsigned) ((int32_t) (uint32_t) v106 + (uint32_t) v18) * (unsigned) v18 + v15 * (unsigned) v19), v114, v115); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + TLOAD(v56, v116); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v57, v56); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + TMATMUL_ACC(v58, v58, v55, v57); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> 
v117(nullptr); + TALLOC, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v47, v117); + __gm__ float* v118 = PTOAS__GLOBAL_TENSOR_DATA(v117); + pto::Shape<1, 1, 1, 128, 128> v119 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v120 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v121 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v118, v119, v120); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + TSTORE(v121, v58); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + TPUSH, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v47, v117); + pto::Shape<1, 1, 1, 128, 128> v122 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v123 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v124 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v15 + v15 * (unsigned) v19 + (unsigned) v100 * (unsigned) v18), v122, v123); + TLOAD(v50, v124); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TMOV(v52, v50); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + TMATMUL(v69, v49, v52); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 128, 128> v125 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v126 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v127 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v15 + v15 * (unsigned) v19 + (unsigned) ((int32_t) (uint32_t) v100 + (uint32_t) v18) * (unsigned) v18), v125, v126); + 
wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + TLOAD(v50, v127); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v52, v50); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + TMATMUL(v77, v49, v52); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v128(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v128); + __gm__ float* v129 = PTOAS__GLOBAL_TENSOR_DATA(v128); + pto::Shape<1, 1, 1, 128, 256> v130 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<32768, 32768, 32768, 256, 1> v131 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v132 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v129, v130, v131); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + TSTORE(v132, v53); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v128); + int32_t v133 = (int32_t) ((uint32_t) v99 + (uint32_t) v19); + int32_t v134 = (int32_t) ((uint32_t) ((int32_t) (uint32_t) v133 + (uint32_t) v28) * (uint32_t) v17); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v135(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v45, v135); + __gm__ half* v136 = PTOAS__GLOBAL_TENSOR_DATA(v135); + pto::Shape<1, 1, 1, 128, 128> v137 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<32768, 32768, 32768, 256, 1> v138 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v139 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, 
pto::Layout::ND>(v136, v137, v138); + TLOAD(v54, v139); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v55, v54); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + int32_t v140 = (int32_t) ((uint32_t) v133 * (uint32_t) v17); + pto::Shape<1, 1, 1, 128, 128> v141 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v142 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v143 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v4 + (v15 + (unsigned) v140 * (unsigned) v18 + v15 * (unsigned) v19), v141, v142); + TLOAD(v56, v143); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v57, v56); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + TMATMUL(v58, v55, v57); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + __gm__ half* v144 = PTOAS__GLOBAL_TENSOR_DATA(v135); + pto::Shape<1, 1, 1, 128, 128> v145 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<32768, 32768, 32768, 256, 1> v146 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v147 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v144 + v13, v145, v146); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + TLOAD(v54, v147); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v55, v54); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v45, v135); + pto::Shape<1, 1, 1, 128, 128> v148 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v149 = pto::Stride<16384, 16384, 16384, 128, 1>(); 
+ GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v150 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v4 + (v15 + (unsigned) ((int32_t) (uint32_t) v140 + (uint32_t) v18) * (unsigned) v18 + v15 * (unsigned) v19), v148, v149); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + TLOAD(v56, v150); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v57, v56); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v58, v58, v55, v57); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v151(nullptr); + TALLOC, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v47, v151); + __gm__ float* v152 = PTOAS__GLOBAL_TENSOR_DATA(v151); + pto::Shape<1, 1, 1, 128, 128> v153 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v154 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v155 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v152, v153, v154); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + TSTORE(v155, v58); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + TPUSH, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v47, v151); + pto::Shape<1, 1, 1, 128, 128> v156 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v157 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v158 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v15 + v15 * (unsigned) v19 + (unsigned) v134 * (unsigned) v18), v156, v157); + TLOAD(v51, v158); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + 
wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v52, v51); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID7); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + TMATMUL(v69, v49, v52); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + pto::Shape<1, 1, 1, 128, 128> v159 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<128, 128, 128, 1, 128> v160 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v161 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v15 + v15 * (unsigned) v19 + (unsigned) ((int32_t) (uint32_t) v134 + (uint32_t) v18) * (unsigned) v18), v159, v160); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID7); + TLOAD(v51, v161); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v52, v51); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL(v77, v49, v52); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v162(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v162); + __gm__ float* v163 = PTOAS__GLOBAL_TENSOR_DATA(v162); + pto::Shape<1, 1, 1, 128, 256> v164 = pto::Shape<1, 1, 1, 128, 256>(); + pto::Stride<32768, 32768, 32768, 256, 1> v165 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v166 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v163, v164, v165); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + TSTORE(v166, v53); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v43, v162); + }; + set_flag(PIPE_FIX, 
PIPE_M, EVENT_ID6); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v167(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v45, v167); + __gm__ half* v168 = PTOAS__GLOBAL_TENSOR_DATA(v167); + pto::Shape<1, 1, 1, 128, 128> v169 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<32768, 32768, 32768, 256, 1> v170 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v171 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v168, v169, v170); + TLOAD(v54, v171); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v55, v54); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + pto::Shape<1, 1, 1, 128, 128> v172 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v173 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v174 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v4 + (v15 + v10 * (unsigned) v18 + v15 * (unsigned) v19), v172, v173); + pipe_barrier(PIPE_MTE2); + TLOAD(v56, v174); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v57, v56); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID6); + TMATMUL(v58, v55, v57); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + __gm__ half* v175 = PTOAS__GLOBAL_TENSOR_DATA(v167); + pto::Shape<1, 1, 1, 128, 128> v176 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<32768, 32768, 32768, 256, 1> v177 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v178 = GlobalTensor, 
pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v175 + v13, v176, v177); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v54, v178); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v55, v54); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v45, v167); + pto::Shape<1, 1, 1, 128, 128> v179 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v180 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v181 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v4 + (v15 + v9 * (unsigned) v18 + v15 * (unsigned) v19), v179, v180); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v56, v181); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v57, v56); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v58, v58, v55, v57); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID6); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v182(nullptr); + TALLOC, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v47, v182); + __gm__ float* v183 = PTOAS__GLOBAL_TENSOR_DATA(v182); + pto::Shape<1, 1, 1, 128, 128> v184 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v185 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v186 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v183, v184, v185); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID6); + TSTORE(v186, v58); + set_flag(PIPE_FIX, PIPE_M, 
EVENT_ID7); + TPUSH, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v47, v182); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v187(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v45, v187); + __gm__ half* v188 = PTOAS__GLOBAL_TENSOR_DATA(v187); + pto::Shape<1, 1, 1, 128, 128> v189 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<32768, 32768, 32768, 256, 1> v190 = pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v191 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v188, v189, v190); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v54, v191); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v55, v54); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + pto::Shape<1, 1, 1, 128, 128> v192 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v193 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v194 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v4 + (v15 + v8 * (unsigned) v18 + v15 * (unsigned) v19), v192, v193); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v56, v194); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v57, v56); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID7); + TMATMUL(v58, v55, v57); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + __gm__ half* v195 = PTOAS__GLOBAL_TENSOR_DATA(v187); + pto::Shape<1, 1, 1, 128, 128> v196 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<32768, 32768, 32768, 256, 1> v197 = 
pto::Stride<32768, 32768, 32768, 256, 1>(); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v198 = GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>(v195 + v13, v196, v197); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v54, v198); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v55, v54); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v45, v187); + pto::Shape<1, 1, 1, 128, 128> v199 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v200 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v201 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v4 + (v15 + v7 * (unsigned) v18 + v15 * (unsigned) v19), v199, v200); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + TLOAD(v56, v201); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v57, v56); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL_ACC(v58, v58, v55, v57); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID7); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v202(nullptr); + TALLOC, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v47, v202); + __gm__ float* v203 = PTOAS__GLOBAL_TENSOR_DATA(v202); + pto::Shape<1, 1, 1, 128, 128> v204 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v205 = pto::Stride<16384, 16384, 16384, 128, 1>(); + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND> v206 = GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>(v203, v204, v205); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID7); + 
TSTORE(v206, v58); + TPUSH, GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v47, v202); + } + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +AICORE void vector_kernel(__gm__ float* v1, __gm__ float* v2, __gm__ half* v3, __gm__ float* v4) { + RoundMode v5 = RoundMode::CAST_RINT; + unsigned v6 = 256; + unsigned v7 = 0; + const int32_t v8 = 0; + const int32_t v9 = 16; + const int32_t v10 = 32; + const int32_t v11 = 64; + const int32_t v12 = 128; + const int32_t v13 = 1; + const int32_t v14 = 262144; + const int32_t v15 = 131072; + const int64_t v16 = 196608; + const int64_t v17 = 262144; + const int64_t v18 = 327680; + const int64_t v19 = 360448; + const int64_t v20 = 393216; + const int64_t v21 = 393344; + const int64_t v22 = 393472; + const int64_t v23 = 393600; + const int64_t v24 = 393728; + const int64_t v25 = 393856; + const int64_t v26 = 393984; + const int64_t v27 = 394112; + const int64_t v28 = 394240; + const int64_t v29 = 394368; + const int64_t v30 = 394496; + const int64_t v31 = 394624; + const float v32 = 0.0883883461f; + const float v33 = 1.0f; + const int64_t v34 = 394752; + const int32_t v35 = 7; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + size_t v36 = (size_t) v13; + int64_t v37 = get_block_num(); + int32_t v38 = (int32_t) ((int64_t) v37); + int64_t v39 = get_block_idx(); + int32_t v40 = (int32_t) ((int64_t) v39); + int32_t v41 = v9 / v38; + int32_t v42 = v9 % v38; + int32_t v43 = (int32_t) ((uint32_t) v41 + (uint32_t) v13); + bool v44 = v40 < v42; + int32_t v45 = v44 ? 
(int32_t) ((uint32_t) v40 * (uint32_t) v43) : (int32_t) ((uint32_t) ((int32_t) (uint32_t) v42 * (uint32_t) v43) + (uint32_t) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v40 - (uint32_t) v42) * (uint32_t) v41)); + int32_t v46 = (int32_t) ((uint32_t) v40 * (uint32_t) v14); + int32_t v47 = (int32_t) ((uint32_t) v40 * (uint32_t) v15); + __gm__ float* v48 = v1 + v46; + + auto v49 = TPipe<0, Direction::DIR_C2V, 131072, 8, 8, true>(v48, v8, v8); + __gm__ half* v50 = v3 + v46; + + auto v51 = TPipe<2, Direction::DIR_V2C, 65536, 8, 8, true>(v50, v8, v8); + __gm__ float* v52 = v4 + v47; + + auto v53 = TPipe<4, Direction::DIR_C2V, 65536, 8, 8, false>(v52, v8, v8); + int64_t v54 = get_subblockid(); + int32_t v55 = (int32_t) ((uint32_t) ((int32_t) (int64_t) v54) * (uint32_t) v11); + int32_t v56 = (int32_t) ((uint32_t) v55 + (uint32_t) v10); + Tile v57; + TASSIGN(v57, v16); + Tile v58; + TASSIGN(v58, v17); + Tile v59; + TASSIGN(v59, v18); + Tile v60; + TASSIGN(v60, v19); + Tile v61; + TASSIGN(v61, v24); + Tile v62; + TASSIGN(v62, v28); + Tile v63; + TASSIGN(v63, v30); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + for (size_t v64 = (size_t) v45; v64 < ((size_t) ((int32_t) (uint32_t) v45 + (uint32_t) (v44 ? 
v43 : v41))); v64 += v36) { + Tile v65; + TASSIGN(v65, v34); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v66(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v66); + __gm__ float* v67 = PTOAS__GLOBAL_TENSOR_DATA(v66); + pto::Shape<1, 1, 1, 32, 256> v68 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v69 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v70 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v67 + (v7 + (unsigned) v55 * v6), v68, v69); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TLOAD(v65, v70); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v65, v65, v32); + Tile v71; + TASSIGN(v71, v20); + Tile v72; + TASSIGN(v72, v22); + Tile v73; + TASSIGN(v73, v24); + pipe_barrier(PIPE_V); + TROWMAX(v72, v65, v58); + Tile v74; + TRESHAPE(v74, v72); + Tile v75; + TRESHAPE(v75, v71); + pipe_barrier(PIPE_V); + TROWEXPANDSUB(v58, v65, v72); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TMULS(v75, v74, v33); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + pipe_barrier(PIPE_V); + TROWSUM(v73, v58, v57); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v76(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v76); + __gm__ half* v77 = PTOAS__GLOBAL_TENSOR_DATA(v76); + pto::Shape<1, 1, 1, 32, 256> v78 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v79 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v80 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v77 + (v7 + (unsigned) v55 * v6), v78, 
v79); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v80, v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + __gm__ float* v81 = PTOAS__GLOBAL_TENSOR_DATA(v66); + pto::Shape<1, 1, 1, 32, 256> v82 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v83 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v84 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v81 + (v7 + (unsigned) v56 * v6), v82, v83); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v65, v84); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v66); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TMULS(v65, v65, v32); + Tile v85; + TASSIGN(v85, v21); + Tile v86; + TASSIGN(v86, v23); + Tile v87; + TASSIGN(v87, v25); + pipe_barrier(PIPE_V); + TROWMAX(v86, v65, v58); + Tile v88; + TRESHAPE(v88, v86); + Tile v89; + TRESHAPE(v89, v85); + pipe_barrier(PIPE_V); + TROWEXPANDSUB(v58, v65, v86); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TMULS(v89, v88, v33); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + pipe_barrier(PIPE_V); + TROWSUM(v87, v58, v57); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + __gm__ half* v90 = PTOAS__GLOBAL_TENSOR_DATA(v76); + pto::Shape<1, 1, 1, 32, 256> v91 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v92 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v93 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v90 + (v7 + (unsigned) v56 * v6), v91, v92); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TSTORE(v93, v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v76); + Tile v94; + 
TASSIGN(v94, v34); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v95(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v95); + __gm__ float* v96 = PTOAS__GLOBAL_TENSOR_DATA(v95); + pto::Shape<1, 1, 1, 32, 256> v97 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v98 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v99 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v96 + (v7 + (unsigned) v55 * v6), v97, v98); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v94, v99); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + TMULS(v94, v94, v32); + Tile v100; + TASSIGN(v100, v20); + Tile v101; + TASSIGN(v101, v22); + Tile v102; + TASSIGN(v102, v24); + Tile v103; + TASSIGN(v103, v26); + Tile v104; + TASSIGN(v104, v30); + pipe_barrier(PIPE_V); + TROWMAX(v101, v94, v58); + Tile v105; + TRESHAPE(v105, v101); + Tile v106; + TRESHAPE(v106, v100); + Tile v107; + TRESHAPE(v107, v104); + Tile v108; + TRESHAPE(v108, v102); + Tile v109; + TRESHAPE(v109, v103); + pipe_barrier(PIPE_V); + TMAX(v105, v105, v106); + pipe_barrier(PIPE_V); + TSUB(v107, v106, v105); + pipe_barrier(PIPE_V); + TMULS(v106, v105, v33); + TROWEXPANDSUB(v58, v94, v101); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + TEXP(v107, v107); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + TMUL(v108, v108, v107); + pipe_barrier(PIPE_V); + TROWSUM(v103, v58, v57); + pipe_barrier(PIPE_V); + TADD(v108, v108, v109); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID2); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v110(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v110); + __gm__ half* v111 = 
PTOAS__GLOBAL_TENSOR_DATA(v110); + pto::Shape<1, 1, 1, 32, 256> v112 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v113 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v114 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v111 + (v7 + (unsigned) v55 * v6), v112, v113); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID2); + TSTORE(v114, v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + __gm__ float* v115 = PTOAS__GLOBAL_TENSOR_DATA(v95); + pto::Shape<1, 1, 1, 32, 256> v116 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v117 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v118 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v115 + (v7 + (unsigned) v56 * v6), v116, v117); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + TLOAD(v94, v118); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v95); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + TMULS(v94, v94, v32); + Tile v119; + TASSIGN(v119, v21); + Tile v120; + TASSIGN(v120, v23); + Tile v121; + TASSIGN(v121, v25); + Tile v122; + TASSIGN(v122, v27); + Tile v123; + TASSIGN(v123, v31); + pipe_barrier(PIPE_V); + TROWMAX(v120, v94, v58); + Tile v124; + TRESHAPE(v124, v120); + Tile v125; + TRESHAPE(v125, v119); + Tile v126; + TRESHAPE(v126, v123); + Tile v127; + TRESHAPE(v127, v121); + Tile v128; + TRESHAPE(v128, v122); + pipe_barrier(PIPE_V); + TMAX(v124, v124, v125); + pipe_barrier(PIPE_V); + TSUB(v126, v125, v124); + pipe_barrier(PIPE_V); + TMULS(v125, v124, v33); + TROWEXPANDSUB(v58, v94, v120); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + TEXP(v126, v126); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + TMUL(v127, v127, v126); + pipe_barrier(PIPE_V); + TROWSUM(v122, v58, v57); + pipe_barrier(PIPE_V); 
+ TADD(v127, v127, v128); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + __gm__ half* v129 = PTOAS__GLOBAL_TENSOR_DATA(v110); + pto::Shape<1, 1, 1, 32, 256> v130 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v131 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v132 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v129 + (v7 + (unsigned) v56 * v6), v130, v131); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + TSTORE(v132, v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID4); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v110); + Tile v133; + TASSIGN(v133, v34); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v134(nullptr); + TPOP, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v134); + __gm__ float* v135 = PTOAS__GLOBAL_TENSOR_DATA(v134); + pto::Shape<1, 1, 1, 64, 128> v136 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v137 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v138 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v135, v136, v137); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + TLOAD(v133, v138); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + TMOV(v60, v133); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID5); + TFREE, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v134); + Tile v139; + TASSIGN(v139, v34); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v140(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v140); + __gm__ float* v141 
= PTOAS__GLOBAL_TENSOR_DATA(v140); + pto::Shape<1, 1, 1, 32, 256> v142 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v143 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v144 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v141 + (v7 + (unsigned) v55 * v6), v142, v143); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID5); + TLOAD(v139, v144); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID5); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID5); + TMULS(v139, v139, v32); + Tile v145; + TASSIGN(v145, v20); + Tile v146; + TASSIGN(v146, v22); + Tile v147; + TASSIGN(v147, v24); + Tile v148; + TASSIGN(v148, v26); + Tile v149; + TASSIGN(v149, v28); + pipe_barrier(PIPE_V); + TROWMAX(v146, v139, v58); + Tile v150; + TRESHAPE(v150, v146); + Tile v151; + TRESHAPE(v151, v145); + Tile v152; + TRESHAPE(v152, v149); + Tile v153; + TRESHAPE(v153, v147); + Tile v154; + TRESHAPE(v154, v148); + pipe_barrier(PIPE_V); + TMAX(v150, v150, v151); + pipe_barrier(PIPE_V); + TSUB(v152, v151, v150); + pipe_barrier(PIPE_V); + TMULS(v151, v150, v33); + TROWEXPANDSUB(v58, v139, v146); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID6); + TEXP(v152, v152); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + TMUL(v153, v153, v152); + pipe_barrier(PIPE_V); + TROWSUM(v148, v58, v57); + pipe_barrier(PIPE_V); + TADD(v153, v153, v154); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID4); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID4); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v155(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v155); + __gm__ half* v156 = PTOAS__GLOBAL_TENSOR_DATA(v155); + pto::Shape<1, 1, 1, 32, 256> v157 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v158 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> 
v159 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v156 + (v7 + (unsigned) v55 * v6), v157, v158); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID4); + TSTORE(v159, v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID5); + __gm__ float* v160 = PTOAS__GLOBAL_TENSOR_DATA(v140); + pto::Shape<1, 1, 1, 32, 256> v161 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v162 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v163 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v160 + (v7 + (unsigned) v56 * v6), v161, v162); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID6); + TLOAD(v139, v163); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID6); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v140); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID6); + TMULS(v139, v139, v32); + Tile v164; + TASSIGN(v164, v21); + Tile v165; + TASSIGN(v165, v23); + Tile v166; + TASSIGN(v166, v25); + Tile v167; + TASSIGN(v167, v27); + Tile v168; + TASSIGN(v168, v29); + pipe_barrier(PIPE_V); + TROWMAX(v165, v139, v58); + Tile v169; + TRESHAPE(v169, v165); + Tile v170; + TRESHAPE(v170, v164); + Tile v171; + TRESHAPE(v171, v168); + Tile v172; + TRESHAPE(v172, v166); + Tile v173; + TRESHAPE(v173, v167); + pipe_barrier(PIPE_V); + TMAX(v169, v169, v170); + pipe_barrier(PIPE_V); + TSUB(v171, v170, v169); + pipe_barrier(PIPE_V); + TMULS(v170, v169, v33); + TROWEXPANDSUB(v58, v139, v165); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + TEXP(v171, v171); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + TMUL(v172, v172, v171); + pipe_barrier(PIPE_V); + TROWSUM(v167, v58, v57); + pipe_barrier(PIPE_V); + TADD(v172, v172, v173); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID5); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID5); + __gm__ half* v174 = PTOAS__GLOBAL_TENSOR_DATA(v155); + pto::Shape<1, 1, 1, 32, 256> v175 = pto::Shape<1, 1, 1, 32, 
256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v176 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v177 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v174 + (v7 + (unsigned) v56 * v6), v175, v176); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID5); + TSTORE(v177, v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v155); + Tile v178; + TASSIGN(v178, v34); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v179(nullptr); + TPOP, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v179); + __gm__ float* v180 = PTOAS__GLOBAL_TENSOR_DATA(v179); + pto::Shape<1, 1, 1, 64, 128> v181 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v182 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v183 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v180, v181, v182); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + TLOAD(v178, v183); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TROWEXPANDMUL(v60, v60, v63); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v60, v60, v178); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TFREE, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v179); + Tile v184; + TASSIGN(v184, v34); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v185(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v185); + __gm__ float* v186 = PTOAS__GLOBAL_TENSOR_DATA(v185); + pto::Shape<1, 1, 1, 32, 256> v187 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v188 = pto::Stride<8192, 8192, 8192, 256, 
1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v189 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v186 + (v7 + (unsigned) v55 * v6), v187, v188); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v184, v189); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v184, v184, v32); + Tile v190; + TASSIGN(v190, v20); + Tile v191; + TASSIGN(v191, v22); + Tile v192; + TASSIGN(v192, v24); + Tile v193; + TASSIGN(v193, v26); + Tile v194; + TASSIGN(v194, v30); + pipe_barrier(PIPE_V); + TROWMAX(v191, v184, v58); + Tile v195; + TRESHAPE(v195, v191); + Tile v196; + TRESHAPE(v196, v190); + Tile v197; + TRESHAPE(v197, v194); + Tile v198; + TRESHAPE(v198, v192); + Tile v199; + TRESHAPE(v199, v193); + pipe_barrier(PIPE_V); + TMAX(v195, v195, v196); + pipe_barrier(PIPE_V); + TSUB(v197, v196, v195); + pipe_barrier(PIPE_V); + TMULS(v196, v195, v33); + TROWEXPANDSUB(v58, v184, v191); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TEXP(v197, v197); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + TMUL(v198, v198, v197); + pipe_barrier(PIPE_V); + TROWSUM(v193, v58, v57); + pipe_barrier(PIPE_V); + TADD(v198, v198, v199); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID6); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v200(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v200); + __gm__ half* v201 = PTOAS__GLOBAL_TENSOR_DATA(v200); + pto::Shape<1, 1, 1, 32, 256> v202 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v203 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v204 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v201 + (v7 + (unsigned) v55 * v6), v202, v203); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID6); + TSTORE(v204, 
v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID7); + __gm__ float* v205 = PTOAS__GLOBAL_TENSOR_DATA(v185); + pto::Shape<1, 1, 1, 32, 256> v206 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v207 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v208 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v205 + (v7 + (unsigned) v56 * v6), v206, v207); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v184, v208); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v185); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v184, v184, v32); + Tile v209; + TASSIGN(v209, v21); + Tile v210; + TASSIGN(v210, v23); + Tile v211; + TASSIGN(v211, v25); + Tile v212; + TASSIGN(v212, v27); + Tile v213; + TASSIGN(v213, v31); + pipe_barrier(PIPE_V); + TROWMAX(v210, v184, v58); + Tile v214; + TRESHAPE(v214, v210); + Tile v215; + TRESHAPE(v215, v209); + Tile v216; + TRESHAPE(v216, v213); + Tile v217; + TRESHAPE(v217, v211); + Tile v218; + TRESHAPE(v218, v212); + pipe_barrier(PIPE_V); + TMAX(v214, v214, v215); + pipe_barrier(PIPE_V); + TSUB(v216, v215, v214); + pipe_barrier(PIPE_V); + TMULS(v215, v214, v33); + TROWEXPANDSUB(v58, v184, v210); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TEXP(v216, v216); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + TMUL(v217, v217, v216); + pipe_barrier(PIPE_V); + TROWSUM(v212, v58, v57); + pipe_barrier(PIPE_V); + TADD(v217, v217, v218); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID7); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID7); + __gm__ half* v219 = PTOAS__GLOBAL_TENSOR_DATA(v200); + pto::Shape<1, 1, 1, 32, 256> v220 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v221 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v222 = GlobalTensor, 
pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v219 + (v7 + (unsigned) v56 * v6), v220, v221); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID7); + TSTORE(v222, v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v200); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + for (size_t v223 = v36; v223 < ((size_t) v35); v223 += v36) { + Tile v224; + TASSIGN(v224, v34); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v225(nullptr); + TPOP, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v225); + __gm__ float* v226 = PTOAS__GLOBAL_TENSOR_DATA(v225); + pto::Shape<1, 1, 1, 64, 128> v227 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v228 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v229 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v226, v227, v228); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v224, v229); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TROWEXPANDMUL(v60, v60, v62); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v60, v60, v224); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TFREE, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v225); + Tile v230; + TASSIGN(v230, v34); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v231(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v231); + __gm__ float* v232 = PTOAS__GLOBAL_TENSOR_DATA(v231); + pto::Shape<1, 1, 1, 32, 256> v233 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> 
v234 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v235 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v232 + (v7 + (unsigned) v55 * v6), v233, v234); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v230, v235); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v230, v230, v32); + Tile v236; + TASSIGN(v236, v20); + Tile v237; + TASSIGN(v237, v22); + Tile v238; + TASSIGN(v238, v24); + Tile v239; + TASSIGN(v239, v26); + Tile v240; + TASSIGN(v240, v28); + pipe_barrier(PIPE_V); + TROWMAX(v237, v230, v58); + Tile v241; + TRESHAPE(v241, v237); + Tile v242; + TRESHAPE(v242, v236); + Tile v243; + TRESHAPE(v243, v240); + Tile v244; + TRESHAPE(v244, v238); + Tile v245; + TRESHAPE(v245, v239); + pipe_barrier(PIPE_V); + TMAX(v241, v241, v242); + pipe_barrier(PIPE_V); + TSUB(v243, v242, v241); + pipe_barrier(PIPE_V); + TMULS(v242, v241, v33); + TROWEXPANDSUB(v58, v230, v237); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TEXP(v243, v243); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + TMUL(v244, v244, v243); + pipe_barrier(PIPE_V); + TROWSUM(v239, v58, v57); + pipe_barrier(PIPE_V); + TADD(v244, v244, v245); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v246(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v246); + __gm__ half* v247 = PTOAS__GLOBAL_TENSOR_DATA(v246); + pto::Shape<1, 1, 1, 32, 256> v248 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v249 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v250 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v247 + (v7 + (unsigned) v55 * v6), v248, v249); + wait_flag(PIPE_V, 
PIPE_MTE3, EVENT_ID0); + TSTORE(v250, v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + __gm__ float* v251 = PTOAS__GLOBAL_TENSOR_DATA(v231); + pto::Shape<1, 1, 1, 32, 256> v252 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v253 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v254 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v251 + (v7 + (unsigned) v56 * v6), v252, v253); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v230, v254); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v231); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v230, v230, v32); + Tile v255; + TASSIGN(v255, v21); + Tile v256; + TASSIGN(v256, v23); + Tile v257; + TASSIGN(v257, v25); + Tile v258; + TASSIGN(v258, v27); + Tile v259; + TASSIGN(v259, v29); + pipe_barrier(PIPE_V); + TROWMAX(v256, v230, v58); + Tile v260; + TRESHAPE(v260, v256); + Tile v261; + TRESHAPE(v261, v255); + Tile v262; + TRESHAPE(v262, v259); + Tile v263; + TRESHAPE(v263, v257); + Tile v264; + TRESHAPE(v264, v258); + pipe_barrier(PIPE_V); + TMAX(v260, v260, v261); + pipe_barrier(PIPE_V); + TSUB(v262, v261, v260); + pipe_barrier(PIPE_V); + TMULS(v261, v260, v33); + TROWEXPANDSUB(v58, v230, v256); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TEXP(v262, v262); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + TMUL(v263, v263, v262); + pipe_barrier(PIPE_V); + TROWSUM(v258, v58, v57); + pipe_barrier(PIPE_V); + TADD(v263, v263, v264); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + __gm__ half* v265 = PTOAS__GLOBAL_TENSOR_DATA(v246); + pto::Shape<1, 1, 1, 32, 256> v266 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v267 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, 
pto::Layout::ND> v268 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v265 + (v7 + (unsigned) v56 * v6), v266, v267); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v268, v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v246); + Tile v269; + TASSIGN(v269, v34); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v270(nullptr); + TPOP, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v270); + __gm__ float* v271 = PTOAS__GLOBAL_TENSOR_DATA(v270); + pto::Shape<1, 1, 1, 64, 128> v272 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v273 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v274 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v271, v272, v273); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v269, v274); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TROWEXPANDMUL(v60, v60, v63); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v60, v60, v269); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TFREE, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v270); + Tile v275; + TASSIGN(v275, v34); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v276(nullptr); + TPOP, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v276); + __gm__ float* v277 = PTOAS__GLOBAL_TENSOR_DATA(v276); + pto::Shape<1, 1, 1, 32, 256> v278 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v279 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v280 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, 
pto::Layout::ND>(v277 + (v7 + (unsigned) v55 * v6), v278, v279); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v275, v280); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v275, v275, v32); + Tile v281; + TASSIGN(v281, v20); + Tile v282; + TASSIGN(v282, v22); + Tile v283; + TASSIGN(v283, v24); + Tile v284; + TASSIGN(v284, v26); + Tile v285; + TASSIGN(v285, v30); + pipe_barrier(PIPE_V); + TROWMAX(v282, v275, v58); + Tile v286; + TRESHAPE(v286, v282); + Tile v287; + TRESHAPE(v287, v281); + Tile v288; + TRESHAPE(v288, v285); + Tile v289; + TRESHAPE(v289, v283); + Tile v290; + TRESHAPE(v290, v284); + pipe_barrier(PIPE_V); + TMAX(v286, v286, v287); + pipe_barrier(PIPE_V); + TSUB(v288, v287, v286); + pipe_barrier(PIPE_V); + TMULS(v287, v286, v33); + TROWEXPANDSUB(v58, v275, v282); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TEXP(v288, v288); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + TMUL(v289, v289, v288); + pipe_barrier(PIPE_V); + TROWSUM(v284, v58, v57); + pipe_barrier(PIPE_V); + TADD(v289, v289, v290); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND> v291(nullptr); + TALLOC, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v291); + __gm__ half* v292 = PTOAS__GLOBAL_TENSOR_DATA(v291); + pto::Shape<1, 1, 1, 32, 256> v293 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v294 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v295 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v292 + (v7 + (unsigned) v55 * v6), v293, v294); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v295, v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + __gm__ float* v296 = PTOAS__GLOBAL_TENSOR_DATA(v276); + pto::Shape<1, 1, 1, 32, 256> v297 = 
pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v298 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v299 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v296 + (v7 + (unsigned) v56 * v6), v297, v298); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TLOAD(v275, v299); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TFREE, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v49, v276); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v275, v275, v32); + Tile v300; + TASSIGN(v300, v21); + Tile v301; + TASSIGN(v301, v23); + Tile v302; + TASSIGN(v302, v25); + Tile v303; + TASSIGN(v303, v27); + Tile v304; + TASSIGN(v304, v31); + pipe_barrier(PIPE_V); + TROWMAX(v301, v275, v58); + Tile v305; + TRESHAPE(v305, v301); + Tile v306; + TRESHAPE(v306, v300); + Tile v307; + TRESHAPE(v307, v304); + Tile v308; + TRESHAPE(v308, v302); + Tile v309; + TRESHAPE(v309, v303); + pipe_barrier(PIPE_V); + TMAX(v305, v305, v306); + pipe_barrier(PIPE_V); + TSUB(v307, v306, v305); + pipe_barrier(PIPE_V); + TMULS(v306, v305, v33); + TROWEXPANDSUB(v58, v275, v301); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TEXP(v307, v307); + pipe_barrier(PIPE_V); + TEXP(v58, v58); + TMUL(v308, v308, v307); + pipe_barrier(PIPE_V); + TROWSUM(v303, v58, v57); + pipe_barrier(PIPE_V); + TADD(v308, v308, v309); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TCVT(v59, v58, v5); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + __gm__ half* v310 = PTOAS__GLOBAL_TENSOR_DATA(v291); + pto::Shape<1, 1, 1, 32, 256> v311 = pto::Shape<1, 1, 1, 32, 256>(); + pto::Stride<8192, 8192, 8192, 256, 1> v312 = pto::Stride<8192, 8192, 8192, 256, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND> v313 = GlobalTensor, pto::Stride<8192, 8192, 8192, 256, 1>, pto::Layout::ND>(v310 + (v7 + (unsigned) v56 * v6), v311, v312); + wait_flag(PIPE_V, PIPE_MTE3, 
EVENT_ID0); + TSTORE(v313, v59); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + TPUSH, GlobalTensor, pto::Stride<32768, 32768, 32768, 256, 1>, pto::Layout::ND>, TileSplitAxis::TILE_NO_SPLIT>(v51, v291); + }; + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + Tile v314; + TASSIGN(v314, v34); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v315(nullptr); + TPOP, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v315); + __gm__ float* v316 = PTOAS__GLOBAL_TENSOR_DATA(v315); + pto::Shape<1, 1, 1, 64, 128> v317 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v318 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v319 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v316, v317, v318); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v314, v319); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TROWEXPANDMUL(v60, v60, v62); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v60, v60, v314); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TFREE, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v315); + Tile v320; + TASSIGN(v320, v34); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v321(nullptr); + TPOP, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v321); + __gm__ float* v322 = PTOAS__GLOBAL_TENSOR_DATA(v321); + pto::Shape<1, 1, 1, 64, 128> v323 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v324 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v325 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, 
pto::Layout::ND>(v322, v323, v324); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + TLOAD(v320, v325); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TROWEXPANDMUL(v60, v60, v63); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v60, v60, v320); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TFREE, GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>, TileSplitAxis::TILE_UP_DOWN>(v53, v321); + pipe_barrier(PIPE_V); + TROWEXPANDDIV(v60, v60, v61); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pto::Shape<1, 1, 1, 64, 128> v326 = pto::Shape<1, 1, 1, 64, 128>(); + pto::Stride<8192, 8192, 8192, 128, 1> v327 = pto::Stride<8192, 8192, 8192, 128, 1>(); + GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND> v328 = GlobalTensor, pto::Stride<8192, 8192, 8192, 128, 1>, pto::Layout::ND>(v2 + (v7 + (unsigned) ((int32_t) (uint32_t) ((int32_t) (uint32_t) ((int32_t) v64) * (uint32_t) v12) + (uint32_t) v55) * (unsigned) v12 + v7 * (unsigned) v13), v326, v327); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TSTORE(v328, v60); + } + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +__global__ AICORE void call_both(__gm__ int64_t* v1, __gm__ half* v2, __gm__ half* v3, __gm__ half* v4, __gm__ half* v5, __gm__ float* v6, __gm__ float* v7, __gm__ float* v8) { + using T = float; + uint64_t v9 = (uint64_t) v1; + set_ffts_base_addr(v9); + cube_kernel(v7, v2, v3, v4, v5, v8); + vector_kernel(v7, v6, v5, v8); + return; +} diff --git a/examples/aot/flash_attention/ir_ref/fa_perf.pto b/examples/aot/flash_attention/ir_ref/fa_perf.pto new file mode 100644 index 00000000..1184907f --- /dev/null +++ b/examples/aot/flash_attention/ir_ref/fa_perf.pto @@ -0,0 +1,743 @@ +// RUN: ptoas --pto-arch=a3 --pto-level=level3 --enable-insert-sync %s >/dev/null + +module { + func.func @cube_kernel(%qk_fifo: 
!pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr, %p_fifo: !pto.ptr, %pv_fifo: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c128_0 = arith.constant 128 : index + %c256 = arith.constant 256 : index + %c4096 = arith.constant 4096 : index + %c128_1 = arith.constant 128 : index + %c16 = arith.constant 16 : index + %c16_2 = arith.constant 16 : index + %0 = pto.get_block_num + %1 = arith.index_cast %0 : i64 to index + %2 = pto.get_block_idx + %3 = arith.index_cast %2 : i64 to index + %4 = arith.divsi %c16_2, %1 : index + %5 = arith.remsi %c16_2, %1 : index + %6 = arith.addi %4, %c1 : index + %7 = arith.muli %3, %6 : index + %8 = arith.addi %4, %c1 : index + %9 = arith.muli %5, %8 : index + %10 = arith.subi %3, %5 : index + %11 = arith.muli %10, %4 : index + %12 = arith.addi %9, %11 : index + %13 = arith.cmpi slt, %3, %5 : index + %14 = arith.select %13, %7, %12 : index + %15 = arith.cmpi slt, %3, %5 : index + %16 = arith.addi %4, %c1 : index + %17 = arith.select %15, %16, %4 : index + %18 = arith.addi %14, %17 : index + %c262144 = arith.constant 262144 : index + %19 = arith.muli %3, %c262144 : index + %21 = pto.addptr %qk_fifo, %19 : -> + %p_block = pto.addptr %p_fifo, %19 : -> + %c131072 = arith.constant 131072 : index + %pv_block_off = arith.muli %3, %c131072 : index + %22 = pto.addptr %pv_fifo, %pv_block_off : -> + %qk_slot_desc = pto.make_tensor_view %21, shape = [%c128, %c256], strides = [%c256, %c1] : !pto.tensor_view<128x256xf32> + pto.aic_initialize_pipe{id = 25, dir_mask = 1, slot_size = 131072} (gm_slot_tensor = %qk_slot_desc : !pto.tensor_view<128x256xf32>) + %p_slot_desc = pto.make_tensor_view %p_block, shape = [%c128, %c256], strides = [%c256, %c1] : !pto.tensor_view<128x256xf16> + pto.aic_initialize_pipe{id = 30, dir_mask = 2, slot_size = 65536} (gm_slot_tensor = %p_slot_desc : !pto.tensor_view<128x256xf16>) + %pv_slot_desc = 
pto.make_tensor_view %22, shape = [%c128, %c128_0], strides = [%c128_0, %c1] : !pto.tensor_view<128x128xf32> + pto.aic_initialize_pipe{id = 27, dir_mask = 1, slot_size = 65536} (gm_slot_tensor = %pv_slot_desc : !pto.tensor_view<128x128xf32>) + %c0_i64 = arith.constant 0 : i64 + %c0_i64_4 = arith.constant 0 : i64 + %29 = pto.alloc_tile addr = %c0_i64_4 : !pto.tile_buf + %c0_i64_5 = arith.constant 0 : i64 + %30 = pto.alloc_tile addr = %c0_i64_5 : !pto.tile_buf + %c32768_i64 = arith.constant 32768 : i64 + %31 = pto.alloc_tile addr = %c32768_i64 : !pto.tile_buf + %c65536_i64 = arith.constant 65536 : i64 + %32 = pto.alloc_tile addr = %c65536_i64 : !pto.tile_buf + %33 = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %c0_i64_6 = arith.constant 0 : i64 + %34 = pto.alloc_tile addr = %c0_i64_6 : !pto.tile_buf + %c98304_i64 = arith.constant 98304 : i64 + %35 = pto.alloc_tile addr = %c98304_i64 : !pto.tile_buf + %c32768_i64_7 = arith.constant 32768 : i64 + %36 = pto.alloc_tile addr = %c32768_i64_7 : !pto.tile_buf + %c229376_i64 = arith.constant 229376 : i64 + %37 = pto.alloc_tile addr = %c229376_i64 : !pto.tile_buf + %38 = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %c131072_i64 = arith.constant 131072 : i64 + %39 = pto.alloc_tile addr = %c65536_i64 : !pto.tile_buf + %c2048 = arith.constant 2048 : index + %40 = pto.make_tensor_view %arg1, shape = [%c2048, %c128_0], strides = [%c128_0, %c1] : !pto.tensor_view + %41 = pto.make_tensor_view %arg2, shape = [%c128_0, %c4096], strides = [%c1, %c128_0] : !pto.tensor_view + %42 = pto.make_tensor_view %arg3, shape = [%c4096, %c128_0], strides = [%c128_0, %c1] : !pto.tensor_view + scf.for %arg4 = %14 to %18 step %c1 { + %43 = arith.muli %arg4, %c128 : index + %44 = pto.partition_view %40, offsets = [%43, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%44 : !pto.partition_tensor_view<128x128xf16>) outs(%29 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%30 : 
!pto.tile_buf) + %c0_8 = arith.constant 0 : index + %c0_9 = arith.constant 0 : index + %45 = arith.addi %c0_8, %c0_9 : index + %46 = pto.partition_view %41, offsets = [%c0, %45], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%46 : !pto.partition_tensor_view<128x128xf16>) outs(%31 : !pto.tile_buf) + pto.tmov ins(%31 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c0_10 = arith.constant 0 : index + %47 = pto.subview %34[%c0, %c0_10] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%47 : !pto.tile_buf) + %c128_11 = arith.constant 128 : index + %48 = arith.addi %c0_8, %c128_11 : index + %49 = pto.partition_view %41, offsets = [%c0, %48], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%49 : !pto.partition_tensor_view<128x128xf16>) outs(%31 : !pto.tile_buf) + pto.tmov ins(%31 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c128_12 = arith.constant 128 : index + %50 = pto.subview %34[%c0, %c128_12] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%50 : !pto.tile_buf) + %qk_push_0 = pto.talloc_to_aiv {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_push_part_0 = pto.partition_view %qk_push_0, offsets = [%c0, %c0], sizes = [%c128, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<128x256xf32> + pto.tstore ins(%34 : !pto.tile_buf) outs(%qk_push_part_0 : !pto.partition_tensor_view<128x256xf32>) + pto.tpush_to_aiv(%qk_push_0 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + %c256_13 = arith.constant 256 : index + %c0_14 = arith.constant 0 : index + %51 = arith.addi %c256_13, %c0_14 : index + %52 = pto.partition_view %41, offsets = [%c0, %51], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%52 : !pto.partition_tensor_view<128x128xf16>) outs(%32 : 
!pto.tile_buf) + pto.tmov ins(%32 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c0_15 = arith.constant 0 : index + %53 = pto.subview %34[%c0, %c0_15] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%53 : !pto.tile_buf) + %c128_16 = arith.constant 128 : index + %54 = arith.addi %c256_13, %c128_16 : index + %55 = pto.partition_view %41, offsets = [%c0, %54], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%55 : !pto.partition_tensor_view<128x128xf16>) outs(%32 : !pto.tile_buf) + pto.tmov ins(%32 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c128_17 = arith.constant 128 : index + %56 = pto.subview %34[%c0, %c128_17] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%56 : !pto.tile_buf) + %qk_push_1 = pto.talloc_to_aiv {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_push_part_1 = pto.partition_view %qk_push_1, offsets = [%c0, %c0], sizes = [%c128, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<128x256xf32> + pto.tstore ins(%34 : !pto.tile_buf) outs(%qk_push_part_1 : !pto.partition_tensor_view<128x256xf32>) + pto.tpush_to_aiv(%qk_push_1 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + %57 = pto.partition_view %42, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%57 : !pto.partition_tensor_view<128x128xf16>) outs(%37 : !pto.tile_buf) + %c2 = arith.constant 2 : index + %c7 = arith.constant 7 : index + scf.for %arg5 = %c0 to %c7 step %c1 { + %61 = arith.muli %arg5, %c2 : index + %c2_18 = arith.constant 2 : index + %62 = arith.addi %61, %c2_18 : index + %63 = arith.muli %62, %c256 : index + %p_pop_0 = pto.tpop_from_aiv {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_pop_part_0 = pto.partition_view %p_pop_0, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : 
!pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%p_pop_part_0 : !pto.partition_tensor_view<128x128xf16>) outs(%35 : !pto.tile_buf) + pto.tmov ins(%35 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + %pv0_v_base = arith.muli %61, %c256 : index + %pv0_v_part_0 = pto.partition_view %42, offsets = [%pv0_v_base, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%pv0_v_part_0 : !pto.partition_tensor_view<128x128xf16>) outs(%37 : !pto.tile_buf) + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tmatmul ins(%36, %38 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %p_pop_part_0_hi = pto.partition_view %p_pop_0, offsets = [%c0, %c128], sizes = [%c128, %c128_0] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%p_pop_part_0_hi : !pto.partition_tensor_view<128x128xf16>) outs(%35 : !pto.tile_buf) + pto.tmov ins(%35 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + pto.tfree_from_aiv(%p_pop_0 : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + %pv0_v_hi = arith.addi %pv0_v_base, %c128 : index + %pv0_v_part_1 = pto.partition_view %42, offsets = [%pv0_v_hi, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%pv0_v_part_1 : !pto.partition_tensor_view<128x128xf16>) outs(%37 : !pto.tile_buf) + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %pv_push_0 = pto.talloc_to_aiv {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_push_part_0 = pto.partition_view %pv_push_0, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<128x128xf32> + pto.tstore ins(%39 : !pto.tile_buf) outs(%pv_push_part_0 : !pto.partition_tensor_view<128x128xf32>) + pto.tpush_to_aiv(%pv_push_0 : 
!pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %c0_19 = arith.constant 0 : index + %68 = arith.addi %63, %c0_19 : index + %69 = pto.partition_view %41, offsets = [%c0, %68], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%69 : !pto.partition_tensor_view<128x128xf16>) outs(%31 : !pto.tile_buf) + pto.tmov ins(%31 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c0_20 = arith.constant 0 : index + %70 = pto.subview %34[%c0, %c0_20] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%70 : !pto.tile_buf) + %c128_21 = arith.constant 128 : index + %71 = arith.addi %63, %c128_21 : index + %72 = pto.partition_view %41, offsets = [%c0, %71], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%72 : !pto.partition_tensor_view<128x128xf16>) outs(%31 : !pto.tile_buf) + pto.tmov ins(%31 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c128_22 = arith.constant 128 : index + %73 = pto.subview %34[%c0, %c128_22] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%73 : !pto.tile_buf) + %qk_push_2 = pto.talloc_to_aiv {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_push_part_2 = pto.partition_view %qk_push_2, offsets = [%c0, %c0], sizes = [%c128, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<128x256xf32> + pto.tstore ins(%34 : !pto.tile_buf) outs(%qk_push_part_2 : !pto.partition_tensor_view<128x256xf32>) + pto.tpush_to_aiv(%qk_push_2 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + %74 = arith.muli %arg5, %c2 : index + %75 = arith.addi %74, %c1 : index + %c2_23 = arith.constant 2 : index + %76 = arith.addi %75, %c2_23 : index + %77 = arith.muli %76, %c256 : index + %p_pop_1 = pto.tpop_from_aiv {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_pop_part_1 = pto.partition_view %p_pop_1, offsets = [%c0, 
%c0], sizes = [%c128, %c128_0] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%p_pop_part_1 : !pto.partition_tensor_view<128x128xf16>) outs(%35 : !pto.tile_buf) + pto.tmov ins(%35 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + %pv1_v_base = arith.muli %75, %c256 : index + %pv1_v_part_0 = pto.partition_view %42, offsets = [%pv1_v_base, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%pv1_v_part_0 : !pto.partition_tensor_view<128x128xf16>) outs(%37 : !pto.tile_buf) + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tmatmul ins(%36, %38 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %p_pop_part_1_hi = pto.partition_view %p_pop_1, offsets = [%c0, %c128], sizes = [%c128, %c128_0] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%p_pop_part_1_hi : !pto.partition_tensor_view<128x128xf16>) outs(%35 : !pto.tile_buf) + pto.tmov ins(%35 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + pto.tfree_from_aiv(%p_pop_1 : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + %pv1_v_hi = arith.addi %pv1_v_base, %c128 : index + %pv1_v_part_1 = pto.partition_view %42, offsets = [%pv1_v_hi, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%pv1_v_part_1 : !pto.partition_tensor_view<128x128xf16>) outs(%37 : !pto.tile_buf) + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %pv_push_1 = pto.talloc_to_aiv {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_push_part_1 = pto.partition_view %pv_push_1, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<128x128xf32> + pto.tstore ins(%39 : !pto.tile_buf) outs(%pv_push_part_1 : !pto.partition_tensor_view<128x128xf32>) + 
pto.tpush_to_aiv(%pv_push_1 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %c0_24 = arith.constant 0 : index + %82 = arith.addi %77, %c0_24 : index + %83 = pto.partition_view %41, offsets = [%c0, %82], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%83 : !pto.partition_tensor_view<128x128xf16>) outs(%32 : !pto.tile_buf) + pto.tmov ins(%32 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c0_25 = arith.constant 0 : index + %84 = pto.subview %34[%c0, %c0_25] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%84 : !pto.tile_buf) + %c128_26 = arith.constant 128 : index + %85 = arith.addi %77, %c128_26 : index + %86 = pto.partition_view %41, offsets = [%c0, %85], sizes = [%c128_0, %c128_1] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%86 : !pto.partition_tensor_view<128x128xf16>) outs(%32 : !pto.tile_buf) + pto.tmov ins(%32 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + %c128_27 = arith.constant 128 : index + %87 = pto.subview %34[%c0, %c128_27] sizes [128, 128] : !pto.tile_buf -> !pto.tile_buf + pto.tmatmul ins(%30, %33 : !pto.tile_buf, !pto.tile_buf) outs(%87 : !pto.tile_buf) + %qk_push_3 = pto.talloc_to_aiv {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_push_part_3 = pto.partition_view %qk_push_3, offsets = [%c0, %c0], sizes = [%c128, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<128x256xf32> + pto.tstore ins(%34 : !pto.tile_buf) outs(%qk_push_part_3 : !pto.partition_tensor_view<128x256xf32>) + pto.tpush_to_aiv(%qk_push_3 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + } + %p_pop_2 = pto.tpop_from_aiv {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_pop_part_2 = pto.partition_view %p_pop_2, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%p_pop_part_2 : 
!pto.partition_tensor_view<128x128xf16>) outs(%35 : !pto.tile_buf) + pto.tmov ins(%35 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + %c3584 = arith.constant 3584 : index + %pv2_v_part_0 = pto.partition_view %42, offsets = [%c3584, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%pv2_v_part_0 : !pto.partition_tensor_view<128x128xf16>) outs(%37 : !pto.tile_buf) + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tmatmul ins(%36, %38 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %p_pop_part_2_hi = pto.partition_view %p_pop_2, offsets = [%c0, %c128], sizes = [%c128, %c128_0] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%p_pop_part_2_hi : !pto.partition_tensor_view<128x128xf16>) outs(%35 : !pto.tile_buf) + pto.tmov ins(%35 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + pto.tfree_from_aiv(%p_pop_2 : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + %c3712 = arith.constant 3712 : index + %pv2_v_part_1 = pto.partition_view %42, offsets = [%c3712, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%pv2_v_part_1 : !pto.partition_tensor_view<128x128xf16>) outs(%37 : !pto.tile_buf) + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %pv_push_2 = pto.talloc_to_aiv {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_push_part_2 = pto.partition_view %pv_push_2, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<128x128xf32> + pto.tstore ins(%39 : !pto.tile_buf) outs(%pv_push_part_2 : !pto.partition_tensor_view<128x128xf32>) + pto.tpush_to_aiv(%pv_push_2 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + %p_pop_3 = pto.tpop_from_aiv {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_pop_part_3 
= pto.partition_view %p_pop_3, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%p_pop_part_3 : !pto.partition_tensor_view<128x128xf16>) outs(%35 : !pto.tile_buf) + pto.tmov ins(%35 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + %c3840 = arith.constant 3840 : index + %pv3_v_part_0 = pto.partition_view %42, offsets = [%c3840, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%pv3_v_part_0 : !pto.partition_tensor_view<128x128xf16>) outs(%37 : !pto.tile_buf) + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tmatmul ins(%36, %38 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %p_pop_part_3_hi = pto.partition_view %p_pop_3, offsets = [%c0, %c128], sizes = [%c128, %c128_0] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%p_pop_part_3_hi : !pto.partition_tensor_view<128x128xf16>) outs(%35 : !pto.tile_buf) + pto.tmov ins(%35 : !pto.tile_buf) outs(%36 : !pto.tile_buf) + pto.tfree_from_aiv(%p_pop_3 : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + %c3968 = arith.constant 3968 : index + %pv3_v_part_1 = pto.partition_view %42, offsets = [%c3968, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<128x128xf16> + pto.tload ins(%pv3_v_part_1 : !pto.partition_tensor_view<128x128xf16>) outs(%37 : !pto.tile_buf) + pto.tmov ins(%37 : !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tmatmul.acc ins(%39, %36, %38 : !pto.tile_buf, !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + %pv_push_3 = pto.talloc_to_aiv {id = 27, split = 0} -> !pto.tensor_view<128x128xf32> + %pv_push_part_3 = pto.partition_view %pv_push_3, offsets = [%c0, %c0], sizes = [%c128, %c128_0] : !pto.tensor_view<128x128xf32> -> !pto.partition_tensor_view<128x128xf32> + pto.tstore ins(%39 : !pto.tile_buf) outs(%pv_push_part_3 : !pto.partition_tensor_view<128x128xf32>) 
+ pto.tpush_to_aiv(%pv_push_3 : !pto.tensor_view<128x128xf32>) {id = 27, split = 0} + } + return + } + func.func @vector_kernel(%qk_fifo: !pto.ptr, %arg1: !pto.ptr, %p_fifo: !pto.ptr, %pv_fifo: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c128 = arith.constant 128 : index + %c64 = arith.constant 64 : index + %c32 = arith.constant 32 : index + %c256 = arith.constant 256 : index + %c128_0 = arith.constant 128 : index + %c16 = arith.constant 16 : index + %c16_1 = arith.constant 16 : index + %0 = pto.get_block_num + %1 = arith.index_cast %0 : i64 to index + %2 = pto.get_block_idx + %3 = arith.index_cast %2 : i64 to index + %4 = arith.divsi %c16_1, %1 : index + %5 = arith.remsi %c16_1, %1 : index + %6 = arith.addi %4, %c1 : index + %7 = arith.muli %3, %6 : index + %8 = arith.addi %4, %c1 : index + %9 = arith.muli %5, %8 : index + %10 = arith.subi %3, %5 : index + %11 = arith.muli %10, %4 : index + %12 = arith.addi %9, %11 : index + %13 = arith.cmpi slt, %3, %5 : index + %14 = arith.select %13, %7, %12 : index + %15 = arith.cmpi slt, %3, %5 : index + %16 = arith.addi %4, %c1 : index + %17 = arith.select %15, %16, %4 : index + %18 = arith.addi %14, %17 : index + %c262144 = arith.constant 262144 : index + %19 = arith.muli %3, %c262144 : index + %21 = pto.addptr %qk_fifo, %19 : -> + %p_block = pto.addptr %p_fifo, %19 : -> + %c131072 = arith.constant 131072 : index + %pv_block_off = arith.muli %3, %c131072 : index + %22 = pto.addptr %pv_fifo, %pv_block_off : -> + %qk_slot_desc = pto.make_tensor_view %21, shape = [%c128, %c256], strides = [%c256, %c1] : !pto.tensor_view<128x256xf32> + pto.aiv_initialize_pipe{id = 25, dir_mask = 1, slot_size = 131072} (gm_slot_tensor = %qk_slot_desc : !pto.tensor_view<128x256xf32>) + %p_slot_desc = pto.make_tensor_view %p_block, shape = [%c128, %c256], strides = [%c256, %c1] : !pto.tensor_view<128x256xf16> + pto.aiv_initialize_pipe{id = 30, dir_mask = 2, 
slot_size = 65536} (gm_slot_tensor = %p_slot_desc : !pto.tensor_view<128x256xf16>) + %pv_slot_desc = pto.make_tensor_view %22, shape = [%c64, %c128_0], strides = [%c128_0, %c1] : !pto.tensor_view<64x128xf32> + pto.aiv_initialize_pipe{id = 27, dir_mask = 1, slot_size = 65536} (gm_slot_tensor = %pv_slot_desc : !pto.tensor_view<64x128xf32>) + %29 = pto.get_subblock_idx + %30 = arith.index_cast %29 : i64 to index + %31 = arith.muli %30, %c64 : index + %row_slice_1 = arith.addi %31, %c32 : index + %c196608_i64 = arith.constant 196608 : i64 + %32 = pto.alloc_tile addr = %c196608_i64 : !pto.tile_buf + %c262144_i64 = arith.constant 262144 : i64 + %33 = pto.alloc_tile addr = %c262144_i64 : !pto.tile_buf + %c327680_i64 = arith.constant 327680 : i64 + %34 = pto.alloc_tile addr = %c327680_i64 : !pto.tile_buf + %c360448_i64 = arith.constant 360448 : i64 + %35 = pto.alloc_tile addr = %c360448_i64 : !pto.tile_buf + %c393216_i64 = arith.constant 393216 : i64 + %c393344_i64 = arith.constant 393344 : i64 + %36 = pto.alloc_tile addr = %c393216_i64 : !pto.tile_buf + %c393472_i64 = arith.constant 393472 : i64 + %c393600_i64 = arith.constant 393600 : i64 + %37 = pto.alloc_tile addr = %c393472_i64 : !pto.tile_buf + %c393728_i64 = arith.constant 393728 : i64 + %c393856_i64 = arith.constant 393856 : i64 + %38 = pto.alloc_tile addr = %c393728_i64 : !pto.tile_buf + %c393984_i64 = arith.constant 393984 : i64 + %c394112_i64 = arith.constant 394112 : i64 + %39 = pto.alloc_tile addr = %c393984_i64 : !pto.tile_buf + %c394240_i64 = arith.constant 394240 : i64 + %c394368_i64 = arith.constant 394368 : i64 + %40 = pto.alloc_tile addr = %c394240_i64 : !pto.tile_buf + %c394496_i64 = arith.constant 394496 : i64 + %c394624_i64 = arith.constant 394624 : i64 + %41 = pto.alloc_tile addr = %c394496_i64 : !pto.tile_buf + %cst = arith.constant 0.0883883461 : f32 + %cst_3 = arith.constant 1.000000e+00 : f32 + %c2048 = arith.constant 2048 : index + %42 = pto.make_tensor_view %arg1, shape = [%c2048, %c128_0], 
strides = [%c128_0, %c1] : !pto.tensor_view + scf.for %arg2 = %14 to %18 step %c1 { + %43 = arith.muli %arg2, %c128 : index + %c394752_i64 = arith.constant 394752 : i64 + %44 = pto.alloc_tile addr = %c394752_i64 : !pto.tile_buf + %qk_pop_0 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_0 = pto.partition_view %qk_pop_0, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_0 : !pto.partition_tensor_view<32x256xf32>) outs(%44 : !pto.tile_buf) + pto.tmuls ins(%44, %cst : !pto.tile_buf, f32) outs(%44 : !pto.tile_buf) + %r36_0 = pto.alloc_tile addr = %c393216_i64 : !pto.tile_buf + %r37_0 = pto.alloc_tile addr = %c393472_i64 : !pto.tile_buf + %r38_0 = pto.alloc_tile addr = %c393728_i64 : !pto.tile_buf + %r39_0 = pto.alloc_tile addr = %c393984_i64 : !pto.tile_buf + %r40_0 = pto.alloc_tile addr = %c394240_i64 : !pto.tile_buf + pto.trowmax ins(%44, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_0 : !pto.tile_buf) + %45 = pto.treshape %r37_0 : !pto.tile_buf -> !pto.tile_buf + %46 = pto.treshape %r36_0 : !pto.tile_buf -> !pto.tile_buf + %47 = pto.treshape %r40_0 : !pto.tile_buf -> !pto.tile_buf + %48 = pto.treshape %r38_0 : !pto.tile_buf -> !pto.tile_buf + %49 = pto.treshape %r39_0 : !pto.tile_buf -> !pto.tile_buf + pto.trowexpandsub ins(%44, %r37_0 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmuls ins(%45, %cst_3 : !pto.tile_buf, f32) outs(%46 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%r38_0 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_0 = pto.talloc_to_aic {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_push_part_0 = pto.partition_view %p_push_0, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> 
!pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) outs(%p_push_part_0 : !pto.partition_tensor_view<32x256xf16>) + %qk_pop_part_0_r1 = pto.partition_view %qk_pop_0, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_0_r1 : !pto.partition_tensor_view<32x256xf32>) outs(%44 : !pto.tile_buf) + pto.tfree_from_aic(%qk_pop_0 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + pto.tmuls ins(%44, %cst : !pto.tile_buf, f32) outs(%44 : !pto.tile_buf) + %r36_0_r1 = pto.alloc_tile addr = %c393344_i64 : !pto.tile_buf + %r37_0_r1 = pto.alloc_tile addr = %c393600_i64 : !pto.tile_buf + %r38_0_r1 = pto.alloc_tile addr = %c393856_i64 : !pto.tile_buf + %r39_0_r1 = pto.alloc_tile addr = %c394112_i64 : !pto.tile_buf + %r40_0_r1 = pto.alloc_tile addr = %c394368_i64 : !pto.tile_buf + pto.trowmax ins(%44, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_0_r1 : !pto.tile_buf) + %qk0_r1_max = pto.treshape %r37_0_r1 : !pto.tile_buf -> !pto.tile_buf + %qk0_r1_gmax = pto.treshape %r36_0_r1 : !pto.tile_buf -> !pto.tile_buf + %qk0_r1_tmp = pto.treshape %r40_0_r1 : !pto.tile_buf -> !pto.tile_buf + %qk0_r1_lsum = pto.treshape %r38_0_r1 : !pto.tile_buf -> !pto.tile_buf + %qk0_r1_gsum = pto.treshape %r39_0_r1 : !pto.tile_buf -> !pto.tile_buf + pto.trowexpandsub ins(%44, %r37_0_r1 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmuls ins(%qk0_r1_max, %cst_3 : !pto.tile_buf, f32) outs(%qk0_r1_gmax : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%r38_0_r1 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_part_0_r1 = pto.partition_view %p_push_0, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) 
outs(%p_push_part_0_r1 : !pto.partition_tensor_view<32x256xf16>) + pto.tpush_to_aic(%p_push_0 : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + %c394752_i64_4 = arith.constant 394752 : i64 + %50 = pto.alloc_tile addr = %c394752_i64_4 : !pto.tile_buf + %qk_pop_1 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_1 = pto.partition_view %qk_pop_1, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_1 : !pto.partition_tensor_view<32x256xf32>) outs(%50 : !pto.tile_buf) + pto.tmuls ins(%50, %cst : !pto.tile_buf, f32) outs(%50 : !pto.tile_buf) + %r36_1 = pto.alloc_tile addr = %c393216_i64 : !pto.tile_buf + %r37_1 = pto.alloc_tile addr = %c393472_i64 : !pto.tile_buf + %r38_1 = pto.alloc_tile addr = %c393728_i64 : !pto.tile_buf + %r39_1 = pto.alloc_tile addr = %c393984_i64 : !pto.tile_buf + %r41_1 = pto.alloc_tile addr = %c394496_i64 : !pto.tile_buf + pto.trowmax ins(%50, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_1 : !pto.tile_buf) + %51 = pto.treshape %r37_1 : !pto.tile_buf -> !pto.tile_buf + %52 = pto.treshape %r36_1 : !pto.tile_buf -> !pto.tile_buf + %53 = pto.treshape %r41_1 : !pto.tile_buf -> !pto.tile_buf + %54 = pto.treshape %r38_1 : !pto.tile_buf -> !pto.tile_buf + %55 = pto.treshape %r39_1 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%51, %52 : !pto.tile_buf, !pto.tile_buf) outs(%51 : !pto.tile_buf) + pto.tsub ins(%52, %51 : !pto.tile_buf, !pto.tile_buf) outs(%53 : !pto.tile_buf) + pto.tmuls ins(%51, %cst_3 : !pto.tile_buf, f32) outs(%52 : !pto.tile_buf) + pto.trowexpandsub ins(%50, %r37_1 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%53 : !pto.tile_buf) outs(%53 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%54, %53 : !pto.tile_buf, !pto.tile_buf) outs(%54 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%r39_1 : 
!pto.tile_buf) + pto.tadd ins(%54, %55 : !pto.tile_buf, !pto.tile_buf) outs(%54 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_1 = pto.talloc_to_aic {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_push_part_1 = pto.partition_view %p_push_1, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) outs(%p_push_part_1 : !pto.partition_tensor_view<32x256xf16>) + %qk_pop_part_1_r1 = pto.partition_view %qk_pop_1, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_1_r1 : !pto.partition_tensor_view<32x256xf32>) outs(%50 : !pto.tile_buf) + pto.tfree_from_aic(%qk_pop_1 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + pto.tmuls ins(%50, %cst : !pto.tile_buf, f32) outs(%50 : !pto.tile_buf) + %r36_1_r1 = pto.alloc_tile addr = %c393344_i64 : !pto.tile_buf + %r37_1_r1 = pto.alloc_tile addr = %c393600_i64 : !pto.tile_buf + %r38_1_r1 = pto.alloc_tile addr = %c393856_i64 : !pto.tile_buf + %r39_1_r1 = pto.alloc_tile addr = %c394112_i64 : !pto.tile_buf + %r41_1_r1 = pto.alloc_tile addr = %c394624_i64 : !pto.tile_buf + pto.trowmax ins(%50, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_1_r1 : !pto.tile_buf) + %qk1_r1_max = pto.treshape %r37_1_r1 : !pto.tile_buf -> !pto.tile_buf + %qk1_r1_gmax = pto.treshape %r36_1_r1 : !pto.tile_buf -> !pto.tile_buf + %qk1_r1_diff = pto.treshape %r41_1_r1 : !pto.tile_buf -> !pto.tile_buf + %qk1_r1_lsum = pto.treshape %r38_1_r1 : !pto.tile_buf -> !pto.tile_buf + %qk1_r1_gsum = pto.treshape %r39_1_r1 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%qk1_r1_max, %qk1_r1_gmax : !pto.tile_buf, !pto.tile_buf) outs(%qk1_r1_max : !pto.tile_buf) + pto.tsub ins(%qk1_r1_gmax, %qk1_r1_max : !pto.tile_buf, !pto.tile_buf) outs(%qk1_r1_diff : !pto.tile_buf) + pto.tmuls ins(%qk1_r1_max, %cst_3 : 
!pto.tile_buf, f32) outs(%qk1_r1_gmax : !pto.tile_buf) + pto.trowexpandsub ins(%50, %r37_1_r1 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%qk1_r1_diff : !pto.tile_buf) outs(%qk1_r1_diff : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%qk1_r1_lsum, %qk1_r1_diff : !pto.tile_buf, !pto.tile_buf) outs(%qk1_r1_lsum : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%r39_1_r1 : !pto.tile_buf) + pto.tadd ins(%qk1_r1_lsum, %qk1_r1_gsum : !pto.tile_buf, !pto.tile_buf) outs(%qk1_r1_lsum : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_part_1_r1 = pto.partition_view %p_push_1, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) outs(%p_push_part_1_r1 : !pto.partition_tensor_view<32x256xf16>) + pto.tpush_to_aic(%p_push_1 : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + %c394752_i64_5 = arith.constant 394752 : i64 + %56 = pto.alloc_tile addr = %c394752_i64_5 : !pto.tile_buf + %pv_pop_0 = pto.tpop_from_aic {id = 27, split = 1} -> !pto.tensor_view<64x128xf32> + %pv_pop_part_0 = pto.partition_view %pv_pop_0, offsets = [%c0, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<64x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_0 : !pto.partition_tensor_view<64x128xf32>) outs(%56 : !pto.tile_buf) + pto.tmov ins(%56 : !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop_0 : !pto.tensor_view<64x128xf32>) {id = 27, split = 1} + %c394752_i64_6 = arith.constant 394752 : i64 + %57 = pto.alloc_tile addr = %c394752_i64_6 : !pto.tile_buf + %qk_pop_2 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_2 = pto.partition_view %qk_pop_2, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> 
!pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_2 : !pto.partition_tensor_view<32x256xf32>) outs(%57 : !pto.tile_buf) + pto.tmuls ins(%57, %cst : !pto.tile_buf, f32) outs(%57 : !pto.tile_buf) + %r36_2 = pto.alloc_tile addr = %c393216_i64 : !pto.tile_buf + %r37_2 = pto.alloc_tile addr = %c393472_i64 : !pto.tile_buf + %r38_2 = pto.alloc_tile addr = %c393728_i64 : !pto.tile_buf + %r39_2 = pto.alloc_tile addr = %c393984_i64 : !pto.tile_buf + %r40_2 = pto.alloc_tile addr = %c394240_i64 : !pto.tile_buf + pto.trowmax ins(%57, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_2 : !pto.tile_buf) + %58 = pto.treshape %r37_2 : !pto.tile_buf -> !pto.tile_buf + %59 = pto.treshape %r36_2 : !pto.tile_buf -> !pto.tile_buf + %60 = pto.treshape %r40_2 : !pto.tile_buf -> !pto.tile_buf + %61 = pto.treshape %r38_2 : !pto.tile_buf -> !pto.tile_buf + %62 = pto.treshape %r39_2 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%58, %59 : !pto.tile_buf, !pto.tile_buf) outs(%58 : !pto.tile_buf) + pto.tsub ins(%59, %58 : !pto.tile_buf, !pto.tile_buf) outs(%60 : !pto.tile_buf) + pto.tmuls ins(%58, %cst_3 : !pto.tile_buf, f32) outs(%59 : !pto.tile_buf) + pto.trowexpandsub ins(%57, %r37_2 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%60 : !pto.tile_buf) outs(%60 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%61, %60 : !pto.tile_buf, !pto.tile_buf) outs(%61 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%r39_2 : !pto.tile_buf) + pto.tadd ins(%61, %62 : !pto.tile_buf, !pto.tile_buf) outs(%61 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_2 = pto.talloc_to_aic {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_push_part_2 = pto.partition_view %p_push_2, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) 
outs(%p_push_part_2 : !pto.partition_tensor_view<32x256xf16>) + %qk_pop_part_2_r1 = pto.partition_view %qk_pop_2, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_2_r1 : !pto.partition_tensor_view<32x256xf32>) outs(%57 : !pto.tile_buf) + pto.tfree_from_aic(%qk_pop_2 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + pto.tmuls ins(%57, %cst : !pto.tile_buf, f32) outs(%57 : !pto.tile_buf) + %r36_2_r1 = pto.alloc_tile addr = %c393344_i64 : !pto.tile_buf + %r37_2_r1 = pto.alloc_tile addr = %c393600_i64 : !pto.tile_buf + %r38_2_r1 = pto.alloc_tile addr = %c393856_i64 : !pto.tile_buf + %r39_2_r1 = pto.alloc_tile addr = %c394112_i64 : !pto.tile_buf + %r40_2_r1 = pto.alloc_tile addr = %c394368_i64 : !pto.tile_buf + pto.trowmax ins(%57, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_2_r1 : !pto.tile_buf) + %qk2_r1_max = pto.treshape %r37_2_r1 : !pto.tile_buf -> !pto.tile_buf + %qk2_r1_gmax = pto.treshape %r36_2_r1 : !pto.tile_buf -> !pto.tile_buf + %qk2_r1_diff = pto.treshape %r40_2_r1 : !pto.tile_buf -> !pto.tile_buf + %qk2_r1_lsum = pto.treshape %r38_2_r1 : !pto.tile_buf -> !pto.tile_buf + %qk2_r1_gsum = pto.treshape %r39_2_r1 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%qk2_r1_max, %qk2_r1_gmax : !pto.tile_buf, !pto.tile_buf) outs(%qk2_r1_max : !pto.tile_buf) + pto.tsub ins(%qk2_r1_gmax, %qk2_r1_max : !pto.tile_buf, !pto.tile_buf) outs(%qk2_r1_diff : !pto.tile_buf) + pto.tmuls ins(%qk2_r1_max, %cst_3 : !pto.tile_buf, f32) outs(%qk2_r1_gmax : !pto.tile_buf) + pto.trowexpandsub ins(%57, %r37_2_r1 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%qk2_r1_diff : !pto.tile_buf) outs(%qk2_r1_diff : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%qk2_r1_lsum, %qk2_r1_diff : !pto.tile_buf, !pto.tile_buf) outs(%qk2_r1_lsum : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) 
outs(%r39_2_r1 : !pto.tile_buf) + pto.tadd ins(%qk2_r1_lsum, %qk2_r1_gsum : !pto.tile_buf, !pto.tile_buf) outs(%qk2_r1_lsum : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_part_2_r1 = pto.partition_view %p_push_2, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) outs(%p_push_part_2_r1 : !pto.partition_tensor_view<32x256xf16>) + pto.tpush_to_aic(%p_push_2 : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + %c394752_i64_7 = arith.constant 394752 : i64 + %63 = pto.alloc_tile addr = %c394752_i64_7 : !pto.tile_buf + %pv_pop_1 = pto.tpop_from_aic {id = 27, split = 1} -> !pto.tensor_view<64x128xf32> + %pv_pop_part_1 = pto.partition_view %pv_pop_1, offsets = [%c0, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<64x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_1 : !pto.partition_tensor_view<64x128xf32>) outs(%63 : !pto.tile_buf) + pto.trowexpandmul ins(%35, %41 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %63 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop_1 : !pto.tensor_view<64x128xf32>) {id = 27, split = 1} + %c394752_i64_8 = arith.constant 394752 : i64 + %64 = pto.alloc_tile addr = %c394752_i64_8 : !pto.tile_buf + %qk_pop_3 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_3 = pto.partition_view %qk_pop_3, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_3 : !pto.partition_tensor_view<32x256xf32>) outs(%64 : !pto.tile_buf) + pto.tmuls ins(%64, %cst : !pto.tile_buf, f32) outs(%64 : !pto.tile_buf) + %r36_3 = pto.alloc_tile addr = %c393216_i64 : !pto.tile_buf + %r37_3 = pto.alloc_tile addr = %c393472_i64 : !pto.tile_buf + %r38_3 = pto.alloc_tile addr = 
%c393728_i64 : !pto.tile_buf + %r39_3 = pto.alloc_tile addr = %c393984_i64 : !pto.tile_buf + %r41_3 = pto.alloc_tile addr = %c394496_i64 : !pto.tile_buf + pto.trowmax ins(%64, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_3 : !pto.tile_buf) + %65 = pto.treshape %r37_3 : !pto.tile_buf -> !pto.tile_buf + %66 = pto.treshape %r36_3 : !pto.tile_buf -> !pto.tile_buf + %67 = pto.treshape %r41_3 : !pto.tile_buf -> !pto.tile_buf + %68 = pto.treshape %r38_3 : !pto.tile_buf -> !pto.tile_buf + %69 = pto.treshape %r39_3 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%65, %66 : !pto.tile_buf, !pto.tile_buf) outs(%65 : !pto.tile_buf) + pto.tsub ins(%66, %65 : !pto.tile_buf, !pto.tile_buf) outs(%67 : !pto.tile_buf) + pto.tmuls ins(%65, %cst_3 : !pto.tile_buf, f32) outs(%66 : !pto.tile_buf) + pto.trowexpandsub ins(%64, %r37_3 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%67 : !pto.tile_buf) outs(%67 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%68, %67 : !pto.tile_buf, !pto.tile_buf) outs(%68 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%r39_3 : !pto.tile_buf) + pto.tadd ins(%68, %69 : !pto.tile_buf, !pto.tile_buf) outs(%68 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_3 = pto.talloc_to_aic {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_push_part_3 = pto.partition_view %p_push_3, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) outs(%p_push_part_3 : !pto.partition_tensor_view<32x256xf16>) + %qk_pop_part_3_r1 = pto.partition_view %qk_pop_3, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_3_r1 : !pto.partition_tensor_view<32x256xf32>) outs(%64 : !pto.tile_buf) + pto.tfree_from_aic(%qk_pop_3 : 
!pto.tensor_view<128x256xf32>) {id = 25, split = 0} + pto.tmuls ins(%64, %cst : !pto.tile_buf, f32) outs(%64 : !pto.tile_buf) + %r36_3_r1 = pto.alloc_tile addr = %c393344_i64 : !pto.tile_buf + %r37_3_r1 = pto.alloc_tile addr = %c393600_i64 : !pto.tile_buf + %r38_3_r1 = pto.alloc_tile addr = %c393856_i64 : !pto.tile_buf + %r39_3_r1 = pto.alloc_tile addr = %c394112_i64 : !pto.tile_buf + %r41_3_r1 = pto.alloc_tile addr = %c394624_i64 : !pto.tile_buf + pto.trowmax ins(%64, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_3_r1 : !pto.tile_buf) + %qk3_r1_max = pto.treshape %r37_3_r1 : !pto.tile_buf -> !pto.tile_buf + %qk3_r1_gmax = pto.treshape %r36_3_r1 : !pto.tile_buf -> !pto.tile_buf + %qk3_r1_diff = pto.treshape %r41_3_r1 : !pto.tile_buf -> !pto.tile_buf + %qk3_r1_lsum = pto.treshape %r38_3_r1 : !pto.tile_buf -> !pto.tile_buf + %qk3_r1_gsum = pto.treshape %r39_3_r1 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%qk3_r1_max, %qk3_r1_gmax : !pto.tile_buf, !pto.tile_buf) outs(%qk3_r1_max : !pto.tile_buf) + pto.tsub ins(%qk3_r1_gmax, %qk3_r1_max : !pto.tile_buf, !pto.tile_buf) outs(%qk3_r1_diff : !pto.tile_buf) + pto.tmuls ins(%qk3_r1_max, %cst_3 : !pto.tile_buf, f32) outs(%qk3_r1_gmax : !pto.tile_buf) + pto.trowexpandsub ins(%64, %r37_3_r1 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%qk3_r1_diff : !pto.tile_buf) outs(%qk3_r1_diff : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%qk3_r1_lsum, %qk3_r1_diff : !pto.tile_buf, !pto.tile_buf) outs(%qk3_r1_lsum : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%r39_3_r1 : !pto.tile_buf) + pto.tadd ins(%qk3_r1_lsum, %qk3_r1_gsum : !pto.tile_buf, !pto.tile_buf) outs(%qk3_r1_lsum : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_part_3_r1 = pto.partition_view %p_push_3, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> 
!pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) outs(%p_push_part_3_r1 : !pto.partition_tensor_view<32x256xf16>) + pto.tpush_to_aic(%p_push_3 : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + %c7 = arith.constant 7 : index + scf.for %arg3 = %c1 to %c7 step %c1 { + %c394752_i64_11 = arith.constant 394752 : i64 + %74 = pto.alloc_tile addr = %c394752_i64_11 : !pto.tile_buf + %pv_pop_2 = pto.tpop_from_aic {id = 27, split = 1} -> !pto.tensor_view<64x128xf32> + %pv_pop_part_2 = pto.partition_view %pv_pop_2, offsets = [%c0, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<64x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_2 : !pto.partition_tensor_view<64x128xf32>) outs(%74 : !pto.tile_buf) + pto.trowexpandmul ins(%35, %40 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %74 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop_2 : !pto.tensor_view<64x128xf32>) {id = 27, split = 1} + %c394752_i64_12 = arith.constant 394752 : i64 + %75 = pto.alloc_tile addr = %c394752_i64_12 : !pto.tile_buf + %qk_pop_4 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_4 = pto.partition_view %qk_pop_4, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_4 : !pto.partition_tensor_view<32x256xf32>) outs(%75 : !pto.tile_buf) + pto.tmuls ins(%75, %cst : !pto.tile_buf, f32) outs(%75 : !pto.tile_buf) + %r36_4 = pto.alloc_tile addr = %c393216_i64 : !pto.tile_buf + %r37_4 = pto.alloc_tile addr = %c393472_i64 : !pto.tile_buf + %r38_4 = pto.alloc_tile addr = %c393728_i64 : !pto.tile_buf + %r39_4 = pto.alloc_tile addr = %c393984_i64 : !pto.tile_buf + %r40_4 = pto.alloc_tile addr = %c394240_i64 : !pto.tile_buf + pto.trowmax ins(%75, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_4 : !pto.tile_buf) + %76 = pto.treshape %r37_4 : !pto.tile_buf -> 
!pto.tile_buf + %77 = pto.treshape %r36_4 : !pto.tile_buf -> !pto.tile_buf + %78 = pto.treshape %r40_4 : !pto.tile_buf -> !pto.tile_buf + %79 = pto.treshape %r38_4 : !pto.tile_buf -> !pto.tile_buf + %80 = pto.treshape %r39_4 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%76, %77 : !pto.tile_buf, !pto.tile_buf) outs(%76 : !pto.tile_buf) + pto.tsub ins(%77, %76 : !pto.tile_buf, !pto.tile_buf) outs(%78 : !pto.tile_buf) + pto.tmuls ins(%76, %cst_3 : !pto.tile_buf, f32) outs(%77 : !pto.tile_buf) + pto.trowexpandsub ins(%75, %r37_4 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%78 : !pto.tile_buf) outs(%78 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%79, %78 : !pto.tile_buf, !pto.tile_buf) outs(%79 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%r39_4 : !pto.tile_buf) + pto.tadd ins(%79, %80 : !pto.tile_buf, !pto.tile_buf) outs(%79 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_4 = pto.talloc_to_aic {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_push_part_4 = pto.partition_view %p_push_4, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) outs(%p_push_part_4 : !pto.partition_tensor_view<32x256xf16>) + %qk_pop_part_4_r1 = pto.partition_view %qk_pop_4, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_4_r1 : !pto.partition_tensor_view<32x256xf32>) outs(%75 : !pto.tile_buf) + pto.tfree_from_aic(%qk_pop_4 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + pto.tmuls ins(%75, %cst : !pto.tile_buf, f32) outs(%75 : !pto.tile_buf) + %r36_4_r1 = pto.alloc_tile addr = %c393344_i64 : !pto.tile_buf + %r37_4_r1 = pto.alloc_tile addr = %c393600_i64 : !pto.tile_buf + %r38_4_r1 = pto.alloc_tile 
addr = %c393856_i64 : !pto.tile_buf + %r39_4_r1 = pto.alloc_tile addr = %c394112_i64 : !pto.tile_buf + %r40_4_r1 = pto.alloc_tile addr = %c394368_i64 : !pto.tile_buf + pto.trowmax ins(%75, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_4_r1 : !pto.tile_buf) + %qk4_r1_max = pto.treshape %r37_4_r1 : !pto.tile_buf -> !pto.tile_buf + %qk4_r1_gmax = pto.treshape %r36_4_r1 : !pto.tile_buf -> !pto.tile_buf + %qk4_r1_diff = pto.treshape %r40_4_r1 : !pto.tile_buf -> !pto.tile_buf + %qk4_r1_lsum = pto.treshape %r38_4_r1 : !pto.tile_buf -> !pto.tile_buf + %qk4_r1_gsum = pto.treshape %r39_4_r1 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%qk4_r1_max, %qk4_r1_gmax : !pto.tile_buf, !pto.tile_buf) outs(%qk4_r1_max : !pto.tile_buf) + pto.tsub ins(%qk4_r1_gmax, %qk4_r1_max : !pto.tile_buf, !pto.tile_buf) outs(%qk4_r1_diff : !pto.tile_buf) + pto.tmuls ins(%qk4_r1_max, %cst_3 : !pto.tile_buf, f32) outs(%qk4_r1_gmax : !pto.tile_buf) + pto.trowexpandsub ins(%75, %r37_4_r1 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%qk4_r1_diff : !pto.tile_buf) outs(%qk4_r1_diff : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%qk4_r1_lsum, %qk4_r1_diff : !pto.tile_buf, !pto.tile_buf) outs(%qk4_r1_lsum : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%r39_4_r1 : !pto.tile_buf) + pto.tadd ins(%qk4_r1_lsum, %qk4_r1_gsum : !pto.tile_buf, !pto.tile_buf) outs(%qk4_r1_lsum : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_part_4_r1 = pto.partition_view %p_push_4, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) outs(%p_push_part_4_r1 : !pto.partition_tensor_view<32x256xf16>) + pto.tpush_to_aic(%p_push_4 : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + %c394752_i64_13 = arith.constant 394752 : i64 + %81 = pto.alloc_tile addr = 
%c394752_i64_13 : !pto.tile_buf + %pv_pop_3 = pto.tpop_from_aic {id = 27, split = 1} -> !pto.tensor_view<64x128xf32> + %pv_pop_part_3 = pto.partition_view %pv_pop_3, offsets = [%c0, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<64x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_3 : !pto.partition_tensor_view<64x128xf32>) outs(%81 : !pto.tile_buf) + pto.trowexpandmul ins(%35, %41 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %81 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop_3 : !pto.tensor_view<64x128xf32>) {id = 27, split = 1} + %c394752_i64_14 = arith.constant 394752 : i64 + %82 = pto.alloc_tile addr = %c394752_i64_14 : !pto.tile_buf + %qk_pop_5 = pto.tpop_from_aic {id = 25, split = 0} -> !pto.tensor_view<128x256xf32> + %qk_pop_part_5 = pto.partition_view %qk_pop_5, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_5 : !pto.partition_tensor_view<32x256xf32>) outs(%82 : !pto.tile_buf) + pto.tmuls ins(%82, %cst : !pto.tile_buf, f32) outs(%82 : !pto.tile_buf) + %r36_5 = pto.alloc_tile addr = %c393216_i64 : !pto.tile_buf + %r37_5 = pto.alloc_tile addr = %c393472_i64 : !pto.tile_buf + %r38_5 = pto.alloc_tile addr = %c393728_i64 : !pto.tile_buf + %r39_5 = pto.alloc_tile addr = %c393984_i64 : !pto.tile_buf + %r41_5 = pto.alloc_tile addr = %c394496_i64 : !pto.tile_buf + pto.trowmax ins(%82, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_5 : !pto.tile_buf) + %83 = pto.treshape %r37_5 : !pto.tile_buf -> !pto.tile_buf + %84 = pto.treshape %r36_5 : !pto.tile_buf -> !pto.tile_buf + %85 = pto.treshape %r41_5 : !pto.tile_buf -> !pto.tile_buf + %86 = pto.treshape %r38_5 : !pto.tile_buf -> !pto.tile_buf + %87 = pto.treshape %r39_5 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%83, %84 : !pto.tile_buf, !pto.tile_buf) outs(%83 : !pto.tile_buf) + pto.tsub ins(%84, %83 : 
!pto.tile_buf, !pto.tile_buf) outs(%85 : !pto.tile_buf) + pto.tmuls ins(%83, %cst_3 : !pto.tile_buf, f32) outs(%84 : !pto.tile_buf) + pto.trowexpandsub ins(%82, %r37_5 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%85 : !pto.tile_buf) outs(%85 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%86, %85 : !pto.tile_buf, !pto.tile_buf) outs(%86 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%r39_5 : !pto.tile_buf) + pto.tadd ins(%86, %87 : !pto.tile_buf, !pto.tile_buf) outs(%86 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_5 = pto.talloc_to_aic {id = 30, split = 0} -> !pto.tensor_view<128x256xf16> + %p_push_part_5 = pto.partition_view %p_push_5, offsets = [%31, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) outs(%p_push_part_5 : !pto.partition_tensor_view<32x256xf16>) + %qk_pop_part_5_r1 = pto.partition_view %qk_pop_5, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf32> -> !pto.partition_tensor_view<32x256xf32> + pto.tload ins(%qk_pop_part_5_r1 : !pto.partition_tensor_view<32x256xf32>) outs(%82 : !pto.tile_buf) + pto.tfree_from_aic(%qk_pop_5 : !pto.tensor_view<128x256xf32>) {id = 25, split = 0} + pto.tmuls ins(%82, %cst : !pto.tile_buf, f32) outs(%82 : !pto.tile_buf) + %r36_5_r1 = pto.alloc_tile addr = %c393344_i64 : !pto.tile_buf + %r37_5_r1 = pto.alloc_tile addr = %c393600_i64 : !pto.tile_buf + %r38_5_r1 = pto.alloc_tile addr = %c393856_i64 : !pto.tile_buf + %r39_5_r1 = pto.alloc_tile addr = %c394112_i64 : !pto.tile_buf + %r41_5_r1 = pto.alloc_tile addr = %c394624_i64 : !pto.tile_buf + pto.trowmax ins(%82, %33 : !pto.tile_buf, !pto.tile_buf) outs(%r37_5_r1 : !pto.tile_buf) + %qk5_r1_max = pto.treshape %r37_5_r1 : !pto.tile_buf -> !pto.tile_buf + %qk5_r1_gmax = pto.treshape %r36_5_r1 
: !pto.tile_buf -> !pto.tile_buf + %qk5_r1_diff = pto.treshape %r41_5_r1 : !pto.tile_buf -> !pto.tile_buf + %qk5_r1_lsum = pto.treshape %r38_5_r1 : !pto.tile_buf -> !pto.tile_buf + %qk5_r1_gsum = pto.treshape %r39_5_r1 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%qk5_r1_max, %qk5_r1_gmax : !pto.tile_buf, !pto.tile_buf) outs(%qk5_r1_max : !pto.tile_buf) + pto.tsub ins(%qk5_r1_gmax, %qk5_r1_max : !pto.tile_buf, !pto.tile_buf) outs(%qk5_r1_diff : !pto.tile_buf) + pto.tmuls ins(%qk5_r1_max, %cst_3 : !pto.tile_buf, f32) outs(%qk5_r1_gmax : !pto.tile_buf) + pto.trowexpandsub ins(%82, %r37_5_r1 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%qk5_r1_diff : !pto.tile_buf) outs(%qk5_r1_diff : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%qk5_r1_lsum, %qk5_r1_diff : !pto.tile_buf, !pto.tile_buf) outs(%qk5_r1_lsum : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%r39_5_r1 : !pto.tile_buf) + pto.tadd ins(%qk5_r1_lsum, %qk5_r1_gsum : !pto.tile_buf, !pto.tile_buf) outs(%qk5_r1_lsum : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + %p_push_part_5_r1 = pto.partition_view %p_push_5, offsets = [%row_slice_1, %c0], sizes = [%c32, %c256] : !pto.tensor_view<128x256xf16> -> !pto.partition_tensor_view<32x256xf16> + pto.tstore ins(%34 : !pto.tile_buf) outs(%p_push_part_5_r1 : !pto.partition_tensor_view<32x256xf16>) + pto.tpush_to_aic(%p_push_5 : !pto.tensor_view<128x256xf16>) {id = 30, split = 0} + } + %c394752_i64_9 = arith.constant 394752 : i64 + %70 = pto.alloc_tile addr = %c394752_i64_9 : !pto.tile_buf + %pv_pop_4 = pto.tpop_from_aic {id = 27, split = 1} -> !pto.tensor_view<64x128xf32> + %pv_pop_part_4 = pto.partition_view %pv_pop_4, offsets = [%c0, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<64x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_4 : !pto.partition_tensor_view<64x128xf32>) outs(%70 : 
!pto.tile_buf) + pto.trowexpandmul ins(%35, %40 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %70 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop_4 : !pto.tensor_view<64x128xf32>) {id = 27, split = 1} + %c394752_i64_10 = arith.constant 394752 : i64 + %71 = pto.alloc_tile addr = %c394752_i64_10 : !pto.tile_buf + %pv_pop_5 = pto.tpop_from_aic {id = 27, split = 1} -> !pto.tensor_view<64x128xf32> + %pv_pop_part_5 = pto.partition_view %pv_pop_5, offsets = [%c0, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view<64x128xf32> -> !pto.partition_tensor_view<64x128xf32> + pto.tload ins(%pv_pop_part_5 : !pto.partition_tensor_view<64x128xf32>) outs(%71 : !pto.tile_buf) + pto.trowexpandmul ins(%35, %41 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %71 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aic(%pv_pop_5 : !pto.tensor_view<64x128xf32>) {id = 27, split = 1} + pto.trowexpanddiv ins(%35, %38 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + %72 = arith.addi %43, %31 : index + %73 = pto.partition_view %42, offsets = [%72, %c0], sizes = [%c64, %c128_0] : !pto.tensor_view -> !pto.partition_tensor_view<64x128xf32> + pto.tstore ins(%35 : !pto.tile_buf) outs(%73 : !pto.partition_tensor_view<64x128xf32>) + } + return + } + func.func @call_both(%arg0: memref<256xi64>, %q: !pto.ptr, %k: !pto.ptr, %v: !pto.ptr, %p_fifo: !pto.ptr, %o_out: !pto.ptr, %qk_fifo: !pto.ptr, %pv_fifo: !pto.ptr) attributes {pto.entry} { + pto.set_ffts %arg0 : memref<256xi64> + call @cube_kernel(%qk_fifo, %q, %k, %v, %p_fifo, %pv_fifo) : (!pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr) -> () + call @vector_kernel(%qk_fifo, %o_out, %p_fifo, %pv_fifo) : (!pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr) -> () + return + } +} \ No newline at end of file diff --git a/examples/aot/flash_attention/ir_ref/gen_cpp.sh b/examples/aot/flash_attention/ir_ref/gen_cpp.sh new file mode 
100644 index 00000000..fbc172be --- /dev/null +++ b/examples/aot/flash_attention/ir_ref/gen_cpp.sh @@ -0,0 +1,2 @@ +ptoas --pto-arch=a3 --pto-level=level3 --enable-insert-sync fa.pto > fa.cpp +ptoas --pto-arch=a3 --pto-level=level3 --enable-insert-sync fa_perf.pto > fa_perf.cpp diff --git a/examples/aot/flash_attention/ir_ref/launch_kernel/README.md b/examples/aot/flash_attention/ir_ref/launch_kernel/README.md new file mode 100644 index 00000000..4811068b --- /dev/null +++ b/examples/aot/flash_attention/ir_ref/launch_kernel/README.md @@ -0,0 +1,92 @@ +# `ir_ref/fa.cpp` — compile, launch, correctness / perf + +This folder mirrors the **`split_pipe`** flow (`caller.cpp` + `bisheng` shared library + `torch_npu` Python runner), but compiles the **checked-in IR reference** [`../fa.cpp`](../fa.cpp) produced from [`../fa.pto`](../fa.pto) (`ptoas --pto-arch=a3 --pto-level=level3 --enable-insert-sync`). + +## Baked-in geometry + +| Symbol | Value | +| ----- | ----- | +| `Q_ROWS` | 2048 | +| `HEAD` | 128 | +| `S1_TOTAL` | 4096 | +| `S1_TILE` | 256 (`NUM_TILES` = 16) | + +Host tensors and GM scratch sizes come from `split_pipe/kernels/fa_performance_builder.py` with **`FA_Q_ROWS=2048`**, **`FA_S1_TILE=256`**, **`FA_NUM_TILES=16`** (reload via `fa.build_env` after compile). + +**Launch grid:** this kernel divides **`NUM_TILES` (16)** across `get_block_num()`, not `NUM_Q_BLOCKS`. The runner therefore uses **`blockDim = min(NUM_TILES, num_cube_cores)`** (same idea as mapping tile loops to the grid). Using `NUM_Q_BLOCKS` here triggers device faults. + +## Setup + +From the `pto-dsl` repo root (Python imports `ptodsl`, `torch_npu`): + +```bash +cd /workdir/pto-dsl +pip install -e . 
+``` + +Environment (same as other `examples/aot` demos): + +- `ASCEND_TOOLKIT_HOME` +- `PTO_LIB_PATH` — directory containing `include/pto/` (default in many images: `/sources/pto-isa`) + +## Build `fa.so` + +```bash +cd /workdir/pto-dsl/examples/aot/flash_attention/ir_ref/launch_kernel +bash compile.sh +``` + +Produces `build_artifacts/fa.so` and `build_artifacts/fa.build_env`. + +Regenerate `../fa.cpp` from PTO when needed: + +```bash +cd /workdir/pto-dsl/examples/aot/flash_attention/ir_ref +bash gen_cpp.sh # or: ptoas --pto-arch=a3 --pto-level=level3 --enable-insert-sync fa.pto > fa.cpp +``` + +## Run correctness + benchmark + +```bash +cd /workdir/pto-dsl/examples/aot/flash_attention/ir_ref/launch_kernel +FA_BENCH_NO_PLOT=1 python3 run.py +``` + +Optional: `FA_BENCH_LENGTHS=4096` (default is already `4096`). + +## Results on reference hardware (reproduced here) + +**Environment:** CANN 8.5.0 toolkits, `bisheng --npu-arch=dav-2201`, `ptoas` from `/installers/ptoas-cli/bin/ptoas`, **dav‑2201** NPU, **24** cube cores (see `get_num_cube_cores()`), date **2026-04-30**. + +| Step | Outcome | +| ---- | ------- | +| **`bash compile.sh`** | **Success** (~4s); links `../fa.cpp` via `-DKERNEL_CPP=...` | +| **`python3 run.py`** | **Fails** at `torch.npu.synchronize()` after `call_kernel`: **ACL 507015**, **aicore exception**, **CCU instruction address check** (vector core backtrace in device log) | + +So **`fa.cpp` does not reach a completed device execution** in this configuration; **kernel latency / TFLOP/s are not reported** for it. + +**Related checks on the same machine:** + +- Applying the **`wait_flag_dev` / `ffts_cross_core_sync`** CV handshake from [`split_pipe/debug_cpp/forward_debug/README.md`](../../split_pipe/debug_cpp/forward_debug/README.md) to a copy of `fa.cpp` **still hit 507015** here — the IR layout differs from the forward-debug S256 kernel that clears the sync fault. 
+- **`split_pipe`** AOT at **`FA_S1_TILE=512`**: **same-style 507015** on synchronize (see [`split_pipe/README.md`](../../split_pipe/README.md) §A7). +- **`split_pipe`** AOT at **`FA_S1_TILE=256`**: **synchronize completes** but **`torch.testing.assert_close` fails** (NaNs / large drift — §A11). +- **`split_pipe/debug_cpp/forward_debug`** (`fa.ptoas.forward_edited.cpp`): **synchronize completes**; **numerics still fail** vs fp32 reference (documented there). +- **Baseline hand-written JIT** [`cpp_ref/split_pipe/run.py`](../../cpp_ref/split_pipe/run.py) **passes** correctness on this NPU. Closest reported row (**same `S1=4096`, `HEAD=128`, but `Q_ROWS=3072` and `tile_s1=512`**, not the IR reference geometry): + +```bash +cd /workdir/pto-dsl/examples/aot/flash_attention/cpp_ref/split_pipe +python3 run.py +# excerpt (S1=4096 case), 2026-04-30: +# JIT flash kernel : 0.120 ms/iter (54.307 TFLOP/s) +# npu_fused_infer_attention : 0.254 ms/iter (25.628 TFLOP/s) +``` + +Use that run as an upper-bound sanity reference only; it is **not** the same kernel or tensor shapes as `ir_ref/fa.cpp`. + +## Files + +| File | Role | +| ---- | ---- | +| `compile.sh` | `bisheng` → `build_artifacts/fa.so` | +| `caller.cpp` | `call_kernel` → `call_both<<>>` (same pattern as `split_pipe/caller.cpp`) | +| `run.py` | Loads `.so`, correctness vs `fa_reference`, bench vs `torch_npu.npu_fused_infer_attention_score` | diff --git a/examples/aot/flash_attention/ir_ref/launch_kernel/caller.cpp b/examples/aot/flash_attention/ir_ref/launch_kernel/caller.cpp new file mode 100644 index 00000000..87546451 --- /dev/null +++ b/examples/aot/flash_attention/ir_ref/launch_kernel/caller.cpp @@ -0,0 +1,32 @@ +#ifndef KERNEL_CPP +#error "KERNEL_CPP must be defined at compile time." 
+#endif + +#include + +extern "C" int rtGetC2cCtrlAddr(uint64_t *ctrlAddr, uint32_t *ctrlLen); + +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, + void *stream, + uint8_t *gmSlotBuffer, + uint8_t *q, + uint8_t *k, // K: [S1_TOTAL, HEAD] fp16 + uint8_t *v, + uint8_t *o) // output O: [Q_ROWS, HEAD] fp32 +{ + void *fftsAddr = nullptr; + uint32_t fftsLen = 0; + (void)rtGetC2cCtrlAddr(reinterpret_cast(&fftsAddr), &fftsLen); + (void)fftsLen; + + call_both<<>>( + (__gm__ int64_t *)fftsAddr, + (__gm__ float *)gmSlotBuffer, + (__gm__ half *)q, + (__gm__ half *)k, + (__gm__ half *)v, + (__gm__ float *)o); +} diff --git a/examples/aot/flash_attention/ir_ref/launch_kernel/compile.sh b/examples/aot/flash_attention/ir_ref/launch_kernel/compile.sh new file mode 100755 index 00000000..858eaa69 --- /dev/null +++ b/examples/aot/flash_attention/ir_ref/launch_kernel/compile.sh @@ -0,0 +1,46 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# CANN Open Software License Agreement Version 2.0 +# +# AOT-compile ../fa.cpp (ptoas output from fa.pto; regenerate via ../gen_cpp.sh) into host-loaded fa.so. +# Geometry baked into IR: Q_ROWS=2048, S1_TOTAL=4096 (NUM_TILES=16 × S1_TILE=256), HEAD=128. 
+# +# Usage: +# bash compile.sh + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ARTIFACT_DIR="${SCRIPT_DIR}/build_artifacts" +PTO_LIB_PATH="${PTO_LIB_PATH:-/sources/pto-isa}" +KERNEL_CPP="${SCRIPT_DIR}/../fa.cpp" +GENERATED_SO="${ARTIFACT_DIR}/fa.so" + +mkdir -p "${ARTIFACT_DIR}" +rm -f "${GENERATED_SO}" + +echo "==> bisheng ../fa.cpp -> ${GENERATED_SO}" +bisheng \ + -I"${PTO_LIB_PATH}/include" \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + -cce-enable-mix \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + -DKERNEL_CPP="\"${KERNEL_CPP}\"" \ + "${SCRIPT_DIR}/caller.cpp" \ + -o "${GENERATED_SO}" + +{ + echo "FA_NUM_TILES=16" + echo "FA_S1_TILE=256" + echo "FA_Q_ROWS=2048" +} >"${ARTIFACT_DIR}/fa.build_env" + +echo "Done." diff --git a/examples/aot/flash_attention/ir_ref/launch_kernel/run.py b/examples/aot/flash_attention/ir_ref/launch_kernel/run.py new file mode 100755 index 00000000..33194b19 --- /dev/null +++ b/examples/aot/flash_attention/ir_ref/launch_kernel/run.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# CANN Open Software License Agreement Version 2.0 +# +# Loads `build_artifacts/fa.so` built from `../fa.cpp` (ptoas-generated IR kernel). +# Host-side shapes must match the baked-in MLIR: Q_ROWS=2048, S1_TOTAL=4096, +# S1_TILE=256, NUM_TILES=16, HEAD=128 — kept in sync via `fa.build_env` + fa_performance_builder. 
+# +# Build first: `bash compile.sh` + +import ctypes +import importlib +import math +import os +import sys + +import torch +import torch_npu # noqa: F401 + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_SPLIT_PIPE_KERNELS = os.path.normpath(os.path.join(THIS_DIR, "..", "..", "split_pipe", "kernels")) +sys.path.insert(0, _SPLIT_PIPE_KERNELS) +import fa_performance_builder as fb # noqa: E402 + +from ptodsl import do_bench # noqa: E402 +from ptodsl.utils.npu_info import get_num_cube_cores, get_test_device # noqa: E402 + +ARTIFACT_DIR = os.path.join(THIS_DIR, "build_artifacts") +DEFAULT_PLOT_PATH = os.path.join(ARTIFACT_DIR, "fa_benchmark.png") + +# Single variant for bundled ../fa.cpp (4096 = 16 × 256). +DEFAULT_BENCH_LENGTHS = (4096,) + + +def _parse_bench_lengths(): + raw = os.environ.get("FA_BENCH_LENGTHS") + if not raw: + return DEFAULT_BENCH_LENGTHS + return tuple(int(x) for x in raw.split(",") if x.strip()) + + +ATOL = 1e-3 +RTOL = 1e-3 + + +def attn_flops_matmul_softmax_scale( + batch_size: int, + s_q: int, + s_k: int, + h: int, + include_scale: bool = True, + count_exp_as_flop: bool = True, + count_max_as_flop: bool = True, +): + flops_matmul = 4 * batch_size * s_q * s_k * h + flops_scale = (batch_size * s_q * s_k) if include_scale else 0 + + rows = batch_size * s_q + softmax_ops = 0 + if count_max_as_flop: + softmax_ops += rows * (s_k - 1) + softmax_ops += rows * s_k + if count_exp_as_flop: + softmax_ops += rows * s_k + softmax_ops += rows * (s_k - 1) + softmax_ops += rows * s_k + + return flops_matmul + flops_scale + softmax_ops + + +def get_block_dim() -> int: + # fa.cpp partitions NUM_TILES across get_block_num() / get_block_idx() (see cube/vec + # loops), not NUM_Q_BLOCKS — using NUM_Q_BLOCKS mis-launches the grid and faults (507015). 
+ return min(fb.NUM_TILES, get_num_cube_cores()) + + +def get_slot_elems(block_dim: int) -> int: + return fb.GM_ELEMS_PER_BLOCK * block_dim + + +def num_tiles_for(seq_len: int) -> int: + s1_tile = fb.S1_TILE + if seq_len % s1_tile != 0: + raise ValueError(f"seq_len {seq_len} not divisible by S1_TILE={s1_tile}") + return seq_len // s1_tile + + +def _apply_build_env_matching_seq(seq_len: int) -> None: + if not os.path.isdir(ARTIFACT_DIR): + return + matches = [] + for fn in sorted(os.listdir(ARTIFACT_DIR)): + if not fn.endswith(".build_env"): + continue + path = os.path.join(ARTIFACT_DIR, fn) + kv = {} + try: + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + k, _, v = line.partition("=") + if k.strip(): + kv[k.strip()] = v.strip() + except OSError: + continue + try: + nt = int(kv["FA_NUM_TILES"]) + s1 = int(kv["FA_S1_TILE"]) + except (KeyError, ValueError): + continue + if nt * s1 != seq_len: + continue + matches.append((fn, kv)) + + if not matches: + return + + env_s1 = os.environ.get("FA_S1_TILE") + if env_s1 is not None: + for fn, kv in matches: + if kv.get("FA_S1_TILE") == env_s1.strip(): + for k, v in kv.items(): + os.environ[k] = v + return + + matches.sort(key=lambda x: (0 if x[0] == "fa.build_env" else 1, x[0])) + for _, kv in matches: + for k, v in kv.items(): + os.environ[k] = v + return + + +def lib_path_for(num_tiles: int) -> str: + if num_tiles == 16: + return os.path.join(ARTIFACT_DIR, "fa.so") + return os.path.join(ARTIFACT_DIR, f"fa_{num_tiles}.so") + + +def require_lib(num_tiles: int) -> str: + lib_path = lib_path_for(num_tiles) + if not os.path.exists(lib_path): + raise FileNotFoundError( + f"Missing prebuilt kernel: {lib_path}\n" + "Run `bash compile.sh` in ir_ref/launch_kernel first." 
+ ) + return lib_path + + +def torch_to_ctypes(t: torch.Tensor) -> ctypes.c_void_p: + return ctypes.c_void_p(t.data_ptr()) + + +def load_lib(lib_path: str) -> ctypes.CDLL: + lib = ctypes.CDLL(lib_path) + lib.call_kernel.argtypes = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ] + lib.call_kernel.restype = None + return lib + + +def fa_reference(q, k, v): + scale = 1.0 / math.sqrt(q.shape[1]) + scores = q.float() @ k.float().T * scale + attn = torch.softmax(scores, dim=-1) + return (attn @ v.float()).float() + + +def fused_attention(q, k, v, is_causal=False): + scale = 1.0 / math.sqrt(q.shape[1]) + out, _ = torch_npu.npu_fused_infer_attention_score( + q.unsqueeze(0), + k.unsqueeze(0), + v.unsqueeze(0), + num_heads=1, + input_layout="BSH", + scale=scale, + next_tokens=0 if is_causal else 65535, + ) + return out.squeeze(0) + + +def test_flash(lib, device, num_tiles): + torch.manual_seed(0) + Q_ROWS = fb.Q_ROWS + HEAD = fb.HEAD + S1_TOTAL = fb.S1_TILE * num_tiles + GM_ELEMS_PER_BLOCK = fb.GM_ELEMS_PER_BLOCK + + block_dim = get_block_dim() + slot_elems = get_slot_elems(block_dim) + + q = torch.randn((Q_ROWS, HEAD), dtype=torch.float16, device=device) + k = torch.randn((S1_TOTAL, HEAD), dtype=torch.float16, device=device) + v = torch.randn((S1_TOTAL, HEAD), dtype=torch.float16, device=device) + + gm_slot = torch.zeros((slot_elems,), dtype=torch.float32, device=device) + o = torch.zeros((Q_ROWS, HEAD), dtype=torch.float32, device=device) + + stream_ptr = torch.npu.current_stream()._as_parameter_ + + lib.call_kernel( + block_dim, + stream_ptr, + torch_to_ctypes(gm_slot), + torch_to_ctypes(q), + torch_to_ctypes(k), + torch_to_ctypes(v), + torch_to_ctypes(o), + ) + torch.npu.synchronize() + + o_ref = fa_reference(q, k, v) + torch.testing.assert_close(o.cpu().float(), o_ref.cpu(), rtol=RTOL, atol=ATOL) + print( + f"[fa] q_rows={Q_ROWS} s1={S1_TOTAL} head={HEAD} " + f"({num_tiles} tiles, 
blockDim={block_dim}): PASSED " + f"(atol={ATOL}, rtol={RTOL}) GM/blk={GM_ELEMS_PER_BLOCK} fp32" + ) + + +def benchmark_flash(lib, device, num_tiles, warmup=10, iters=100): + torch.manual_seed(0) + Q_ROWS = fb.Q_ROWS + HEAD = fb.HEAD + S1_TILE = fb.S1_TILE + s1_total = S1_TILE * num_tiles + + block_dim = get_block_dim() + slot_elems = get_slot_elems(block_dim) + + q = torch.randn((Q_ROWS, HEAD), dtype=torch.float16, device=device) + k = torch.randn((s1_total, HEAD), dtype=torch.float16, device=device) + v = torch.randn((s1_total, HEAD), dtype=torch.float16, device=device) + + gm_slot = torch.zeros((slot_elems,), dtype=torch.float32, device=device) + o = torch.zeros((Q_ROWS, HEAD), dtype=torch.float32, device=device) + stream_ptr = torch.npu.current_stream()._as_parameter_ + + def run_kernel(): + lib.call_kernel( + block_dim, + stream_ptr, + torch_to_ctypes(gm_slot), + torch_to_ctypes(q), + torch_to_ctypes(k), + torch_to_ctypes(v), + torch_to_ctypes(o), + ) + + def run_reference(): + fused_attention(q, k, v) + + kernel_us = do_bench( + run_kernel, + warmup_iters=warmup, + benchmark_iters=iters, + unit="us", + flush_cache=False, + ) + ref_us = do_bench( + run_reference, + warmup_iters=warmup, + benchmark_iters=iters, + unit="us", + flush_cache=False, + ) + + run_kernel() + torch.npu.synchronize() + o_kernel = o.clone() + o_fused = fused_attention(q, k, v) + torch.npu.synchronize() + o_golden = fa_reference(q, k, v) + + diff_kernel = (o_kernel.cpu().float() - o_golden.cpu()).abs().max().item() + diff_fused = (o_fused.cpu().float() - o_golden.cpu()).abs().max().item() + torch.testing.assert_close( + o_kernel.cpu().float(), o_golden.cpu(), rtol=RTOL, atol=ATOL + ) + + flops = attn_flops_matmul_softmax_scale(1, Q_ROWS, s1_total, HEAD) + return { + "seq_len": s1_total, + "num_tiles": num_tiles, + "block_dim": block_dim, + "kernel_us": kernel_us, + "ref_us": ref_us, + "kernel_tflops": flops / (kernel_us * 1e-6) / 1e12, + "ref_tflops": flops / (ref_us * 1e-6) / 1e12, + 
"speedup": ref_us / kernel_us, + "kernel_max_err": diff_kernel, + "fused_max_err": diff_fused, + } + + +def print_bench_row(r): + print( + f" s1={r['seq_len']:>6} tiles={r['num_tiles']:>3} " + f"fa={r['kernel_us']:8.2f} us ({r['kernel_tflops']:7.3f} TFLOP/s) " + f"ref={r['ref_us']:8.2f} us ({r['ref_tflops']:7.3f} TFLOP/s) " + f"speedup={r['speedup']:.2f}x " + f"err: ours={r['kernel_max_err']:.2e} ref={r['fused_max_err']:.2e}" + ) + + +def plot_benchmark_results(results, out_png=None): + if not results: + return + + out_png = out_png or os.environ.get("FA_BENCH_PLOT_PATH", DEFAULT_PLOT_PATH) + + try: + import matplotlib + + matplotlib.use("Agg") + import matplotlib.pyplot as plt + except ImportError: + print("Warning: matplotlib is not installed; skipping plot generation.") + return + + style_candidates = ("seaborn-v0_8-whitegrid", "seaborn-whitegrid") + for style_name in style_candidates: + try: + plt.style.use(style_name) + break + except OSError: + continue + + seq_lens = [r["seq_len"] for r in results] + fa_tflops = [r["kernel_tflops"] for r in results] + ref_tflops = [r["ref_tflops"] for r in results] + + fig, ax_thr = plt.subplots(figsize=(7, 5)) + fig.patch.set_facecolor("white") + + ax_thr.plot(seq_lens, fa_tflops, "o-", label="PTO flash attention (ir_ref fa.cpp)") + ax_thr.plot(seq_lens, ref_tflops, "s-", label="torch_npu fused attention") + ax_thr.set_title("Throughput") + ax_thr.set_xlabel("S1 sequence length") + ax_thr.set_ylabel("TFLOP/s") + ax_thr.legend() + ax_thr.set_xscale("log", base=2) + ax_thr.set_xticks(seq_lens) + ax_thr.set_xticklabels([str(x) for x in seq_lens], rotation=30) + + fig.suptitle( + f"Flash Attention Benchmark: Q={fb.Q_ROWS}, H={fb.HEAD}, " + f"S1_TILE={fb.S1_TILE}" + ) + fig.tight_layout() + + out_dir = os.path.dirname(out_png) + if out_dir: + os.makedirs(out_dir, exist_ok=True) + fig.savefig(out_png, dpi=180) + plt.close(fig) + print(f"Saved benchmark plot: {out_png}") + + +def main(): + global fb + device = get_test_device() + 
torch.npu.set_device(device) + + bench_lengths = _parse_bench_lengths() + _apply_build_env_matching_seq(bench_lengths[0]) + fb = importlib.reload(fb) + + required = [(seq_len, num_tiles_for(seq_len)) for seq_len in bench_lengths] + for seq_len, nt in required: + require_lib(nt) + + _, first_nt = required[0] + default_lib = load_lib(require_lib(first_nt)) + test_flash(default_lib, device, num_tiles=first_nt) + + print(f"\n{'Benchmark (fa)':=^96}") + print( + f" Q_ROWS={fb.Q_ROWS} HEAD={fb.HEAD} " + f"S1_TILE={fb.S1_TILE} " + f"NUM_Q_BLOCKS={fb.NUM_Q_BLOCKS} cores={get_num_cube_cores()}" + ) + print(f" lengths: {list(bench_lengths)}") + print("-" * 96) + + results = [] + for seq_len, nt in required: + lib = load_lib(require_lib(nt)) + r = benchmark_flash(lib, device, num_tiles=nt) + print_bench_row(r) + results.append(r) + print("=" * 96) + + if os.environ.get("FA_BENCH_NO_PLOT", "").lower() not in ("1", "true", "yes"): + plot_benchmark_results(results) + + +if __name__ == "__main__": + main() diff --git a/examples/aot/flash_attention/split_pipe/DSL_FIX_TODOS.md b/examples/aot/flash_attention/split_pipe/DSL_FIX_TODOS.md new file mode 100644 index 00000000..9baaa4c8 --- /dev/null +++ b/examples/aot/flash_attention/split_pipe/DSL_FIX_TODOS.md @@ -0,0 +1,54 @@ +# DSL split_pipe kernel — fix backlog (vs `cpp_ref/split_pipe`) + +Goal: **`python run.py`** matches **`../cpp_ref/split_pipe/run.py`** for **correctness** (torch fp32 reference) and **performance** (TFLOP/s in the same ballpark as JIT `fa_performance_kernel.cpp`) on NPU. + +--- + +## Resolved (in builder / workflow) + +- [x] **A1 — VEC UB overflow from pipe `tpop` staging** — stacked layout under 192 KiB; `_VEC_UB_TAIL` assert. +- [x] **A2 — Cube RIGHT bank overflow (`HEAD=128`, `S1_TILE=512`)** — share K/V RIGHT at `RIGHT_KV_OFF=0`. +- [x] **A3 — Duplicate UB offset for `p_fp32` / `p_fp16`** — `VEC_P_FP16_OFF = VEC_P_FP32_OFF + _TILE_FP16_BYTES`; distinct `TASSIGN` in generated vec code. 
+- [x] **B3 — Rebuild `FA_TILES` variants** — `bash compile.sh` produces `fa.so` … `fa_128.so`. +- [x] **A9 — Compile/runtime env documentation** — `FA_Q_ROWS`, `FA_S1_TILE` wired through `compile.sh`; documented in `README.md`. +- [x] **A12 — Host/builder constant drift (NaNs from mismatched `FA_*`)** — `compile.sh` emits **`build_artifacts/fa${TAG}.build_env`** per variant; **`run.py`** applies the sidecar whose **`FA_NUM_TILES * FA_S1_TILE`** equals the first **`FA_BENCH_LENGTHS`** entry (tie-break: **`FA_S1_TILE`** env if set, else **`fa.build_env`**), then **`importlib.reload(fa_performance_builder)`** so GM/tensor sizes match the loaded `.so`. + +--- + +## Open — blocking default port + +- [ ] **A7 — CCU / aicore fault at default `S1_TILE=512`** + **`python run.py`** (default env aligned with `fa.so` + **A12**) still fails at **`torch.npu.synchronize()`** with **ACL 507015**, **CCU instruction address check** on some devices. **`../cpp_ref/split_pipe`** JIT passes on the **same** hardware. Needs bisheng/ptoas or vendor triage against **`build_artifacts/fa.cpp`**. + +- [ ] **A11 — Numerics on alternate tile width** + With **`FA_S1_TILE=256`** and **matching** compile/runtime **`FA_NUM_TILES` / `FA_Q_ROWS`**, the kernel **may run** without the A7 fault but **`torch.testing.assert_close`** can still fail (**NaNs** / large drift). Treat as schedule vs reference macro alignment or toolchain issue separate from A7. + +--- + +## Open — parity / audit (non-blocking for “runs at all”) + +- [ ] **A4 — Bisheng stack / spill** — stack `0x8000`; larger stacks rejected by the toolchain in this environment; not proven to be the root cause of A7. + +- [ ] **A5 — Launch geometry** — reference uses `runTFA<<>>`, DSL uses `min(NUM_Q_BLOCKS, cores)` striping; totals consistent; optional CV audit. + +- [ ] **A6 — Structural parity** — reference: fused **`runTFA`**, **`QK_PRELOAD=4`**, **`CUBE_S0=128`**; DSL: **`pto.call(cube); pto.call(vec)`**, **`QK_PRELOAD=2`**, **`S0=32`**. 
Reference constructs **`TPipe`** in **QK → P → PV** order (`BUF0_QK_READY` / `BUF1_SM_READY` / `UPDATE_READY`); DSL initializes **QK → PV → P** (two **`l2g2l_pipe`** then V2C **`aic/aiv_initialize_pipe`**). Numerical match is the target once A7/A11 clear. + +- [ ] **A8 — CV tail sync vs `runTFA`** — reference ends cube with **`wait_flag_dev(CV_BLOCK_END)`** and vec with **`ffts_cross_core_sync(..., CV_BLOCK_END)`**; DSL output ends with **`ptoas_auto_sync_tail`**. May require future **`pto`/ptoas** hooks if parity demands explicit FFTS tail patterns. + +--- + +## Verification checklist + +- [ ] **B1 — `python run.py`** passes **`torch.testing.assert_close`** vs **`fa_reference`** for default + full **`FA_BENCH_LENGTHS`**. **Blocked by A7 / possibly A11** until NPU/toolchain issues clear. + +- [ ] **B2 — Benchmark vs `cpp_ref/split_pipe/run.py`** (`jit_compile_flash`), same FLOP model (`attn_flops_matmul_softmax_scale`), `flush_cache=False`. **Blocked by A7** for apples-to-apples default geometry. 
+ +--- + +### Status log + +| Date | Change | +| ---- | ------ | +| (session) | Backlog created; A3; NPU CCU fault | +| 2026-04-29 | A3 verified in `fa.cpp`; `compile.sh`: **FA_Q_ROWS**, **FA_S1_TILE**; README status; default **512** path **A7**; **256** path **A11** | +| 2026-04-29 | **A12**: **`fa*.build_env`** + **`run.py`** **`reload`**; README pipe-order wording corrected (**QK→PV→P** vs ref **QK→P→PV**); removed obsolete “**flag_base**” resolved line | diff --git a/examples/aot/flash_attention/split_pipe/README.md b/examples/aot/flash_attention/split_pipe/README.md new file mode 100644 index 00000000..6ef44f3c --- /dev/null +++ b/examples/aot/flash_attention/split_pipe/README.md @@ -0,0 +1,88 @@ +# Flash Attention — split_pipe (PTO-DSL builder) + +Python builder `kernels/fa_performance_builder.py` emits MLIR → `ptoas` → C++ → `bisheng` `.so`, following the same **software-pipelined** FA schedule as the reference hand-written kernel in `../cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.cpp` (QK preload, softmax/GU overlap, `exp_max` ping-pong). The DSL uses explicit pipe helpers (`initialize_l2g2l_pipe`, `tpush`/`tpop` with `TILE_UP_DOWN`) instead of `TMPipe` + `TileSplitAxis` from the C++ file. 
+ +## Current status (vs bundled C++ reference) + +| Check | Result | +| ----- | ------ | +| **`../cpp_ref/split_pipe/run.py`** (JIT `fa_performance_kernel.cpp`) | Passes correctness + benchmark on NPU in this environment | +| **`python run.py`** here — **default** (`S1_TILE=512`, `Q_ROWS=3072`, default `FA_NUM_TILES` matching `fa.so`) | **Fails** at `torch.npu.synchronize()` with **ACL 507015** / **CCU instruction address check** (aicore fault); numerics not validated | +| **`FA_S1_TILE=256`** + matching compile/runtime env | Kernel often **finishes** without that sync fault but **`torch.testing.assert_close` fails** (NaNs / large errors) — **not yet a validated port** | + +So the Python port does **not** yet **run correctly** nor **fully match** the C++ reference at the default cpp_ref-shaped geometry (`HEAD=128`, `tile_s1=512`, large Q). + +### Likely remaining gaps (see `DSL_FIX_TODOS.md`) + +1. **Compile/runtime env parity:** `FA_NUM_TILES`, `FA_S1_TILE`, and `FA_Q_ROWS` are fixed in the emitted MLIR. **`compile.sh`** writes **`build_artifacts/fa${TAG}.build_env`** per variant; **`run.py`** picks the file whose **`FA_NUM_TILES * FA_S1_TILE`** equals the first **`FA_BENCH_LENGTHS`** entry (tie-break: **`FA_S1_TILE`** env if set, else **`fa.build_env`**), then **`importlib.reload`**s **`fa_performance_builder`** so Python GM/tensor shapes match the loaded `.so`. You can still override via exported **`FA_*`** before launch if needed. +2. **FFTS / CV plumbing:** Hand-written `runTFA` constructs **`TPipe`** objects in **QK → P (V2C) → PV** order (`BUF0_QK_READY`, `BUF1_SM_READY`, `UPDATE_READY`). The DSL builder calls **`initialize_l2g2l_pipe`** for **QK** and **PV** first, then **`aic_initialize_pipe` / `aiv_initialize_pipe`** for **P** (**QK → PV → P**). Generated **`TPipe`** indices may therefore differ from the reference header ordering; **`ptoas --enable-insert-sync`** supplies multipipe synchronization. 
**Kernel-tail** **`wait_flag_dev` / `ffts_cross_core_sync`** vs **`ptoas_auto_sync_tail`** remains an **A8** structural gap (see **`DSL_FIX_TODOS.md`**). +3. **`S1_TILE=512` vec path:** Errors correlate with the widest softmax tiles (vs `examples/aot/flash_attention/experimental/` at `S1_TILE=256`, which passes). + +### Debug workflow (recommended) + +1. Regenerate and diff **`build_artifacts/fa.cpp`** against **`../cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.cpp`** for pipe IDs (`TPipe`), schedule phases, and sync tails. +2. After any builder change, **`bash compile.sh`** with the **same** `FA_*` env vars you will use for **`python run.py`**. +3. Compare behaviour with **`../cpp_ref/split_pipe`** on the **same** NPU. + +--- + +## Shapes & defaults + +**`HEAD=128`** and default **`S1_TILE=512`** match `../cpp_ref/split_pipe/run.py` (`test_flash(..., head=128)`, typical `tile_s1=512`). Cube rows per block stay **`S0=32`** (vector softmax tiles **`[S0_HALF, S1_TILE]`**); total Q rows default to **`3072` (`128 * 24`)** like the cpp_ref benchmark `s0`. Override with **`FA_Q_ROWS`** (must match at compile and run). + +Optional **`FA_S1_TILE`** (default `512`) is for experiments; changing it changes **`NUM_TILES`** needed for the same sequence length (`seq_len = NUM_TILES * S1_TILE`) and **must** be rebuilt. + +The reference C++ kernel uses larger cube blocks (`CUBE_S0=128`) with narrower logical vec rows; this DSL example keeps the geometry that fits explicit UB layouts (`S0=32`). + +--- + +## Setup + +```bash +cd /workdir/pto-dsl +pip install -e . +``` + +Environment (same as other `examples/aot` demos): + +- `ASCEND_TOOLKIT_HOME` +- `PTO_LIB_PATH` — directory that contains the `include/` tree with `` (repo root or `.../include`) + +--- + +## Build kernels + +From this directory: + +```bash +bash compile.sh +# Optional: FA_TILES=16,64 bash compile.sh +``` + +`compile.sh` passes through: + +- **`FA_NUM_TILES`** — per variant (from `FA_TILES`). 
+- **`FA_S1_TILE`** — default `512`. +- **`FA_Q_ROWS`** — default `3072`. + +Produces `build_artifacts/fa.so`, `fa_32.so`, … (each variant corresponds to `FA_NUM_TILES` baked into the emitted IR), plus matching **`fa.build_env`**, **`fa_32.build_env`**, … (`FA_NUM_TILES`, `FA_S1_TILE`, `FA_Q_ROWS`) for **`run.py`** auto-sync. + +**Example — 8k sequence with default tile width:** `NUM_TILES=16`, `S1_TILE=512` → `fa.so` + `fa.build_env`. + +--- + +## Correctness + benchmark + +```bash +python run.py +``` + +Override sequence lengths: `FA_BENCH_LENGTHS=8192,32768 python run.py`. + +**Important:** after **`compile.sh`**, **`run.py`** reads **`fa*.build_env`** for the first benchmark length so imported builder constants stay aligned with how that `.so` was built. If you rebuild with different **`FA_S1_TILE`** / **`FA_Q_ROWS`**, re-run **`compile.sh`** (or set matching **`FA_*`** env vars yourself). Ambiguous lengths (e.g. `8192 = 16×512 = 32×256`) require **`FA_S1_TILE`** in the environment or only one matching **`*.build_env`** on disk. + +--- + +## Compare generated C++ to reference + +Diff `build_artifacts/fa.cpp` (or `fa_.cpp`) against `../cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.cpp`: scheduling intent should align on pipeline phases; surface syntax differs (`pto.call` cube/vec vs fused `runTFA`, etc.). diff --git a/examples/aot/flash_attention/split_pipe/caller.cpp b/examples/aot/flash_attention/split_pipe/caller.cpp new file mode 100644 index 00000000..87546451 --- /dev/null +++ b/examples/aot/flash_attention/split_pipe/caller.cpp @@ -0,0 +1,32 @@ +#ifndef KERNEL_CPP +#error "KERNEL_CPP must be defined at compile time." 
+#endif + +#include + +extern "C" int rtGetC2cCtrlAddr(uint64_t *ctrlAddr, uint32_t *ctrlLen); + +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, + void *stream, + uint8_t *gmSlotBuffer, + uint8_t *q, + uint8_t *k, // K: [S1_TOTAL, HEAD] fp16 + uint8_t *v, + uint8_t *o) // output O: [Q_ROWS, HEAD] fp32 +{ + void *fftsAddr = nullptr; + uint32_t fftsLen = 0; + (void)rtGetC2cCtrlAddr(reinterpret_cast(&fftsAddr), &fftsLen); + (void)fftsLen; + + call_both<<>>( + (__gm__ int64_t *)fftsAddr, + (__gm__ float *)gmSlotBuffer, + (__gm__ half *)q, + (__gm__ half *)k, + (__gm__ half *)v, + (__gm__ float *)o); +} diff --git a/examples/aot/flash_attention/split_pipe/codegen_results/README.md b/examples/aot/flash_attention/split_pipe/codegen_results/README.md new file mode 100644 index 00000000..d877c7ab --- /dev/null +++ b/examples/aot/flash_attention/split_pipe/codegen_results/README.md @@ -0,0 +1,90 @@ +# split_pipe — archived PTO codegen (best attempt vs reference C++) + +This folder holds a **frozen snapshot** of the closest match produced today between: + +1. **PTO-DSL builder output** (`fa.pto.mlir`) — MLIR emitted by `kernels/fa_performance_builder.py`. +2. **ptoas-generated Ascend C++** (`fa.ptoas.generated.cpp`) — lowered from that MLIR with **`ptoas --pto-arch=a3 --pto-level=level3 --enable-insert-sync`**. +3. **Hand-written reference kernel** (`fa_performance_kernel.reference.cpp`) — copy of `../cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.cpp` for side‑by‑side comparison. + +“Best attempt” here means the **default split_pipe product**: **`FA_NUM_TILES=16`**, **`FA_S1_TILE=512`**, **`FA_Q_ROWS=3072`** — i.e. the **`fa.so`** / **`fa.mlir`** path aligned with the bundled **`compile.sh`** defaults and the cpp_ref benchmark geometry documented in `../README.md` (**HEAD=128**, wide softmax tile). 
+ +--- + +## How these files were generated + +From `split_pipe/`: + +```bash +# Rebuild only the plain fa.so variant (NUM_TILES=16 → plain filenames). +FA_TILES=16 FA_S1_TILE=512 FA_Q_ROWS=3072 bash compile.sh +``` + +Pipeline (same as `compile.sh`): + +| Step | Command | Output | +|------|---------|--------| +| 1 | `FA_NUM_TILES=16 FA_S1_TILE=512 FA_Q_ROWS=3072 python kernels/fa_performance_builder.py` | stdout redirected to `build_artifacts/fa.mlir` | +| 2 | `ptoas --pto-arch=a3 --pto-level=level3 --enable-insert-sync build_artifacts/fa.mlir` | stdout redirected to `build_artifacts/fa.cpp` | +| 3 | `bisheng … caller.cpp -o build_artifacts/fa.so` | linked host shim + embedded kernel C++ (not duplicated here) | + +Files in **this directory** are copies of the outputs of steps **1–2** plus the reference source (run the commands below from **`split_pipe/`**, i.e. the parent of `codegen_results/`): + +```bash +cp build_artifacts/fa.mlir codegen_results/fa.pto.mlir +cp build_artifacts/fa.cpp codegen_results/fa.ptoas.generated.cpp +cp ../cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.cpp \ + codegen_results/fa_performance_kernel.reference.cpp +``` + +Canonical reference path from **`split_pipe/codegen_results/`** alone: + +`../../cpp_ref/split_pipe/kernels/flash_atten/fa_performance_kernel.cpp` + +**Naming:** The DSL emits **MLIR** (`module { … func.func @cube_kernel … }`). There is no separate `.pto` file extension in this repo; **`fa.pto.mlir`** is the **PTO-DSL MLIR** artifact that users compare against the `ptoas` output. + +--- + +## Gaps: `fa.ptoas.generated.cpp` vs `fa_performance_kernel.reference.cpp` + +Both ultimately include **`pto/pto-inst.hpp`** and lower to **`TPipe` / `TLOAD` / `TMATMUL` / vec ops**, but structure and intent differ substantially. + +### 1.
Kernel shape and entry points + +| Aspect | Reference | DSL → ptoas | +|--------|-----------|-------------| +| Top-level | Template **`runTFA<…>`** fuses QK / softmax (**`compute_p`**) / PV / GU (**`compute_gu`**) in one TU | **`cube_kernel`** (QK + PV matmuls and FIFO I/O) and **`vector_kernel`** (softmax + GU side) as **separate `AICORE` functions**, plus **`call_both`** wrapper | +| Launch API | **`LaunchTFA<…>`** with many GM FIFO / profiling pointers (`fa_performance_kernel.h`) | **`call_both`** (`caller.cpp`) passes **`gm_slot`** scratch + Q/K/V/O tensors | + +The reference is optimized around **macro helpers** (`pto_macro_matmul.hpp`, `pto_macro_fa_softmax.hpp`, `pto_macro_fa_gu.hpp`). The DSL path expresses the same *logical* pipeline via **`pto.*` MLIR ops**, which **ptoas** expands into long SSA-style **`Tile`/`GlobalTensor`** code without those FA macros. + +### 2. Compile-time geometry + +Reference defaults (see `fa_performance_kernel.h`) include **`kFaQkPreload = 4`**, **`kFaTileS1 = 256`**, **`kFaCubeS1 = 128`**, while **`LaunchTFA`** is instantiated from cpp_ref `run.py`/JIT with parameters that may override **`TILE_S1`**. + +This archived DSL build fixes **`S0 = 32`** (vec half-rows **`16`** per sub-block), **`S1_TILE = 512`**, **`HEAD = 128`**, **`NUM_TILES = 16`**, **`QK_PRELOAD = 2`** in `fa_performance_builder.py`. So numerically “closest” is about **matching workload size** (e.g. **8192** sequence length), **not** about matching every **`kFa*`** template constant in the header. + +### 3. FFTS pipes and flag namespaces + +Reference constructs three **`TPipe<…>`** instances (template arguments elided here) using **`FftsBufferFlag`** values (`BUF0_QK_READY`, …). + +ptoas emits **`TPipe<0, Direction::DIR_C2V, …>`**, **`TPipe<2, …>`**, **`TPipe<4, Direction::DIR_V2C, …>`** (see generated file near **`cube_kernel`**).
The **numeric first template arguments (0 / 2 / 4)** line up with **FFTS slot ordering** chosen by the DSL **`initialize_l2g2l_pipe` / `aic_initialize_pipe`** sequence — **not** identical to the reference’s **QK → P → PV** `TPipe` declaration order (DSL initializes **QK → PV → P**). + +### 4. Synchronization and tail behavior + +Reference **`runTFA`** ends with explicit **CV / FFTS** coordination (**`wait_flag_dev(CV_BLOCK_END)`**, **`ffts_cross_core_sync`**, etc., conditioned on **`DAV_CUBE`/`DAV_VEC`**). + +Generated code finishes vector/cube sections with **`ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll)`**, which lowers to **`pipe_barrier(PIPE_ALL)`** (see top of `fa.ptoas.generated.cpp`). That is the **`ptoas --enable-insert-sync`** story — simpler and **not** a line‑by‑line match to the reference tail. + +### 5. Dependencies and portability + +Reference pulls in **`acl`**, **`Pto_prefetch`**, **`TSyncCVID`**, and FA‑specific macro headers. + +Generated kernel code is intentionally minimal (**`pto-inst.hpp` + auto-sync helper**). Host linkage still uses **`caller.cpp`** / **`set_ffts_base_addr`** in the full build. + +--- + +## Using this archive + +- Diff MLIR iterations: `diff -u fa.pto.mlir …` +- Diff lowered C++ vs reference: `diff -u fa.ptoas.generated.cpp fa_performance_kernel.reference.cpp` (expect **large** differences — use the numbered sections above to interpret them). +- After changing **`fa_performance_builder.py`**, regenerate **both** `build_artifacts/` and **these snapshots** so documentation stays reproducible.
diff --git a/examples/aot/flash_attention/split_pipe/codegen_results/fa.pto.mlir b/examples/aot/flash_attention/split_pipe/codegen_results/fa.pto.mlir new file mode 100644 index 00000000..cf00432a --- /dev/null +++ b/examples/aot/flash_attention/split_pipe/codegen_results/fa.pto.mlir @@ -0,0 +1,401 @@ +module { + func.func @cube_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c128 = arith.constant 128 : index + %c512 = arith.constant 512 : index + %c8192 = arith.constant 8192 : index + %c16 = arith.constant 16 : index + %c96 = arith.constant 96 : index + %0 = pto.get_block_num + %1 = arith.index_cast %0 : i64 to index + %2 = pto.get_block_idx + %3 = arith.index_cast %2 : i64 to index + %4 = arith.divsi %c96, %1 : index + %5 = arith.remsi %c96, %1 : index + %6 = arith.addi %4, %c1 : index + %7 = arith.muli %3, %6 : index + %8 = arith.addi %4, %c1 : index + %9 = arith.muli %5, %8 : index + %10 = arith.subi %3, %5 : index + %11 = arith.muli %10, %4 : index + %12 = arith.addi %9, %11 : index + %13 = arith.cmpi slt, %3, %5 : index + %14 = arith.select %13, %7, %12 : index + %15 = arith.cmpi slt, %3, %5 : index + %16 = arith.addi %4, %c1 : index + %17 = arith.select %15, %16, %4 : index + %18 = arith.addi %14, %17 : index + %c229376 = arith.constant 229376 : index + %19 = arith.muli %3, %c229376 : index + %20 = pto.addptr %arg0, %19 : -> + %c0_0 = arith.constant 0 : index + %21 = pto.addptr %20, %c0_0 : -> + %c131072 = arith.constant 131072 : index + %22 = pto.addptr %20, %c131072 : -> + %c163840 = arith.constant 163840 : index + %23 = pto.addptr %20, %c163840 : -> + %24 = pto.import_reserved_buffer{name = "fa_qk_c2v_fifo", peer_func = @vector_kernel} -> i32 + %25 = pto.initialize_l2g2l_pipe{dir_mask = 1, slot_size = 65536, slot_num = 8, local_slot_num = 1} 
(%21 : !pto.ptr, %24 : i32) -> !pto.pipe + %26 = pto.import_reserved_buffer{name = "fa_pv_c2v_fifo", peer_func = @vector_kernel} -> i32 + %27 = pto.initialize_l2g2l_pipe{dir_mask = 1, slot_size = 16384, slot_num = 8, local_slot_num = 1} (%22 : !pto.ptr, %26 : i32) -> !pto.pipe + %28 = pto.reserve_buffer{name = "fa_p_v2c_fifo", size = 262144, location = , auto = false, base = 327680} -> i32 + %c0_i32 = arith.constant 0 : i32 + pto.aic_initialize_pipe {id = 30, dir_mask = 2, slot_size = 32768, nosplit = false}(gm_slot_buffer = %23 : !pto.ptr, c2v_consumer_buf = %c0_i32 : i32, v2c_consumer_buf = %28 : i32) + %c0_i64 = arith.constant 0 : i64 + %c0_i64_1 = arith.constant 0 : i64 + %29 = pto.alloc_tile addr = %c0_i64_1 : !pto.tile_buf + %c0_i64_2 = arith.constant 0 : i64 + %30 = pto.alloc_tile addr = %c0_i64_2 : !pto.tile_buf + %c8192_i64 = arith.constant 8192 : i64 + %31 = pto.alloc_tile addr = %c8192_i64 : !pto.tile_buf + %32 = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %c0_i64_3 = arith.constant 0 : i64 + %33 = pto.alloc_tile addr = %c0_i64_3 : !pto.tile_buf + %c139264_i64 = arith.constant 139264 : i64 + %34 = pto.alloc_tile addr = %c139264_i64 : !pto.tile_buf + %c8192_i64_4 = arith.constant 8192 : i64 + %35 = pto.alloc_tile addr = %c8192_i64_4 : !pto.tile_buf + %c172032_i64 = arith.constant 172032 : i64 + %36 = pto.alloc_tile addr = %c172032_i64 : !pto.tile_buf + %37 = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %c65536_i64 = arith.constant 65536 : i64 + %38 = pto.alloc_tile addr = %c65536_i64 : !pto.tile_buf + %c3072 = arith.constant 3072 : index + %39 = pto.make_tensor_view %arg1, shape = [%c3072, %c128], strides = [%c128, %c1] : !pto.tensor_view + %40 = pto.make_tensor_view %arg2, shape = [%c128, %c8192], strides = [%c1, %c128] : !pto.tensor_view + %41 = pto.make_tensor_view %arg3, shape = [%c8192, %c128], strides = [%c128, %c1] : !pto.tensor_view + scf.for %arg4 = %14 to %18 step %c1 { + %42 = arith.muli %arg4, %c32 : index + %43 = pto.partition_view 
%39, offsets = [%42, %c0], sizes = [%c32, %c128] : !pto.tensor_view + pto.tload ins(%43 : !pto.partition_tensor_view<32x128xf16>) outs(%29 : !pto.tile_buf) + pto.tmov ins(%29 : !pto.tile_buf) outs(%30 : !pto.tile_buf) + %c0_5 = arith.constant 0 : index + %44 = pto.partition_view %40, offsets = [%c0, %c0_5], sizes = [%c128, %c512] : !pto.tensor_view + pto.tload ins(%44 : !pto.partition_tensor_view<128x512xf16>) outs(%31 : !pto.tile_buf) + pto.tmov ins(%31 : !pto.tile_buf) outs(%32 : !pto.tile_buf) + pto.tmatmul ins(%30, %32 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tpush(%33, %25 : !pto.tile_buf, !pto.pipe) {split = 1} + %c512_6 = arith.constant 512 : index + %45 = pto.partition_view %40, offsets = [%c0, %c512_6], sizes = [%c128, %c512] : !pto.tensor_view + pto.tload ins(%45 : !pto.partition_tensor_view<128x512xf16>) outs(%31 : !pto.tile_buf) + pto.tmov ins(%31 : !pto.tile_buf) outs(%32 : !pto.tile_buf) + pto.tmatmul ins(%30, %32 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tpush(%33, %25 : !pto.tile_buf, !pto.pipe) {split = 1} + %46 = pto.partition_view %41, offsets = [%c0, %c0], sizes = [%c512, %c128] : !pto.tensor_view + pto.tload ins(%46 : !pto.partition_tensor_view<512x128xf16>) outs(%36 : !pto.tile_buf) + %c7 = arith.constant 7 : index + scf.for %arg5 = %c0 to %c7 step %c1 { + %50 = arith.muli %arg5, %c2 : index + %c2_7 = arith.constant 2 : index + %51 = arith.addi %50, %c2_7 : index + %52 = arith.muli %51, %c512 : index + %53 = pto.partition_view %40, offsets = [%c0, %52], sizes = [%c128, %c512] : !pto.tensor_view + pto.tload ins(%53 : !pto.partition_tensor_view<128x512xf16>) outs(%31 : !pto.tile_buf) + %54 = pto.tpop_from_aiv {id = 30, split = 1} -> !pto.tile_buf + pto.tmov ins(%54 : !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aiv {id = 30, split = 1} + pto.tmov ins(%36 : !pto.tile_buf) outs(%37 : !pto.tile_buf) + %55 = arith.addi %50, %c1 : index + %56 = arith.muli %55, %c512 : index + %57 = 
pto.partition_view %41, offsets = [%56, %c0], sizes = [%c512, %c128] : !pto.tensor_view + pto.tload ins(%57 : !pto.partition_tensor_view<512x128xf16>) outs(%36 : !pto.tile_buf) + pto.tmatmul ins(%35, %37 : !pto.tile_buf, !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tpush(%38, %27 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.tmov ins(%31 : !pto.tile_buf) outs(%32 : !pto.tile_buf) + pto.tmatmul ins(%30, %32 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tpush(%33, %25 : !pto.tile_buf, !pto.pipe) {split = 1} + %58 = arith.muli %arg5, %c2 : index + %59 = arith.addi %58, %c1 : index + %c2_8 = arith.constant 2 : index + %60 = arith.addi %59, %c2_8 : index + %61 = arith.muli %60, %c512 : index + %62 = pto.partition_view %40, offsets = [%c0, %61], sizes = [%c128, %c512] : !pto.tensor_view + pto.tload ins(%62 : !pto.partition_tensor_view<128x512xf16>) outs(%31 : !pto.tile_buf) + %63 = pto.tpop_from_aiv {id = 30, split = 1} -> !pto.tile_buf + pto.tmov ins(%63 : !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aiv {id = 30, split = 1} + pto.tmov ins(%36 : !pto.tile_buf) outs(%37 : !pto.tile_buf) + %64 = arith.addi %59, %c1 : index + %65 = arith.muli %64, %c512 : index + %66 = pto.partition_view %41, offsets = [%65, %c0], sizes = [%c512, %c128] : !pto.tensor_view + pto.tload ins(%66 : !pto.partition_tensor_view<512x128xf16>) outs(%36 : !pto.tile_buf) + pto.tmatmul ins(%35, %37 : !pto.tile_buf, !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tpush(%38, %27 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.tmov ins(%31 : !pto.tile_buf) outs(%32 : !pto.tile_buf) + pto.tmatmul ins(%30, %32 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tpush(%33, %25 : !pto.tile_buf, !pto.pipe) {split = 1} + } + %47 = pto.tpop_from_aiv {id = 30, split = 1} -> !pto.tile_buf + pto.tmov ins(%47 : !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aiv {id = 30, split = 1} + pto.tmov ins(%36 : !pto.tile_buf) outs(%37 : !pto.tile_buf) + %c7680 = 
arith.constant 7680 : index + %48 = pto.partition_view %41, offsets = [%c7680, %c0], sizes = [%c512, %c128] : !pto.tensor_view + pto.tload ins(%48 : !pto.partition_tensor_view<512x128xf16>) outs(%36 : !pto.tile_buf) + pto.tmatmul ins(%35, %37 : !pto.tile_buf, !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tpush(%38, %27 : !pto.tile_buf, !pto.pipe) {split = 1} + %49 = pto.tpop_from_aiv {id = 30, split = 1} -> !pto.tile_buf + pto.tmov ins(%49 : !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree_from_aiv {id = 30, split = 1} + pto.tmov ins(%36 : !pto.tile_buf) outs(%37 : !pto.tile_buf) + pto.tmatmul ins(%35, %37 : !pto.tile_buf, !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tpush(%38, %27 : !pto.tile_buf, !pto.pipe) {split = 1} + } + return + } + func.func @vector_kernel(%arg0: !pto.ptr, %arg1: !pto.ptr) attributes {pto.kernel_kind = #pto.kernel_kind} { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c32 = arith.constant 32 : index + %c16 = arith.constant 16 : index + %c128 = arith.constant 128 : index + %c16_0 = arith.constant 16 : index + %c96 = arith.constant 96 : index + %0 = pto.get_block_num + %1 = arith.index_cast %0 : i64 to index + %2 = pto.get_block_idx + %3 = arith.index_cast %2 : i64 to index + %4 = arith.divsi %c96, %1 : index + %5 = arith.remsi %c96, %1 : index + %6 = arith.addi %4, %c1 : index + %7 = arith.muli %3, %6 : index + %8 = arith.addi %4, %c1 : index + %9 = arith.muli %5, %8 : index + %10 = arith.subi %3, %5 : index + %11 = arith.muli %10, %4 : index + %12 = arith.addi %9, %11 : index + %13 = arith.cmpi slt, %3, %5 : index + %14 = arith.select %13, %7, %12 : index + %15 = arith.cmpi slt, %3, %5 : index + %16 = arith.addi %4, %c1 : index + %17 = arith.select %15, %16, %4 : index + %18 = arith.addi %14, %17 : index + %c229376 = arith.constant 229376 : index + %19 = arith.muli %3, %c229376 : index + %20 = pto.addptr %arg0, %19 : -> + %c0_1 = arith.constant 0 : index + %21 = pto.addptr %20, 
%c0_1 : -> + %c131072 = arith.constant 131072 : index + %22 = pto.addptr %20, %c131072 : -> + %c163840 = arith.constant 163840 : index + %23 = pto.addptr %20, %c163840 : -> + %24 = pto.reserve_buffer{name = "fa_qk_c2v_fifo", size = 65536, location = , auto = false, base = 0} -> i32 + %25 = pto.initialize_l2g2l_pipe{dir_mask = 1, slot_size = 65536, slot_num = 8, local_slot_num = 1} (%21 : !pto.ptr, %24 : i32) -> !pto.pipe + %26 = pto.reserve_buffer{name = "fa_pv_c2v_fifo", size = 16384, location = , auto = false, base = 65536} -> i32 + %27 = pto.initialize_l2g2l_pipe{dir_mask = 1, slot_size = 16384, slot_num = 8, local_slot_num = 1} (%22 : !pto.ptr, %26 : i32) -> !pto.pipe + %28 = pto.import_reserved_buffer{name = "fa_p_v2c_fifo", peer_func = @cube_kernel} -> i32 + %c0_i32 = arith.constant 0 : i32 + pto.aiv_initialize_pipe {id = 30, dir_mask = 2, slot_size = 32768, nosplit = false}(gm_slot_buffer = %23 : !pto.ptr, c2v_consumer_buf = %c0_i32 : i32, v2c_consumer_buf = %28 : i32) + %29 = pto.get_subblock_idx + %30 = arith.index_cast %29 : i64 to index + %31 = arith.muli %30, %c16 : index + %c114688_i64 = arith.constant 114688 : i64 + %32 = pto.alloc_tile addr = %c114688_i64 : !pto.tile_buf + %c147456_i64 = arith.constant 147456 : i64 + %33 = pto.alloc_tile addr = %c147456_i64 : !pto.tile_buf + %c163840_i64 = arith.constant 163840 : i64 + %34 = pto.alloc_tile addr = %c163840_i64 : !pto.tile_buf + %c180224_i64 = arith.constant 180224 : i64 + %35 = pto.alloc_tile addr = %c180224_i64 : !pto.tile_buf + %c188416_i64 = arith.constant 188416 : i64 + %36 = pto.alloc_tile addr = %c188416_i64 : !pto.tile_buf + %c188480_i64 = arith.constant 188480 : i64 + %37 = pto.alloc_tile addr = %c188480_i64 : !pto.tile_buf + %c188544_i64 = arith.constant 188544 : i64 + %38 = pto.alloc_tile addr = %c188544_i64 : !pto.tile_buf + %c188608_i64 = arith.constant 188608 : i64 + %39 = pto.alloc_tile addr = %c188608_i64 : !pto.tile_buf + %c188672_i64 = arith.constant 188672 : i64 + %40 = 
pto.alloc_tile addr = %c188672_i64 : !pto.tile_buf + %c188736_i64 = arith.constant 188736 : i64 + %41 = pto.alloc_tile addr = %c188736_i64 : !pto.tile_buf + %cst = arith.constant 0.0883883461 : f32 + %cst_2 = arith.constant 1.000000e+00 : f32 + %c3072 = arith.constant 3072 : index + %42 = pto.make_tensor_view %arg1, shape = [%c3072, %c128], strides = [%c128, %c1] : !pto.tensor_view + scf.for %arg2 = %14 to %18 step %c1 { + %43 = arith.muli %arg2, %c32 : index + %c81920_i64 = arith.constant 81920 : i64 + %44 = pto.alloc_tile addr = %c81920_i64 : !pto.tile_buf + pto.tpop(%44, %25 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.tmuls ins(%44, %cst : !pto.tile_buf, f32) outs(%44 : !pto.tile_buf) + pto.trowmax ins(%44, %32 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %45 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %46 = pto.treshape %36 : !pto.tile_buf -> !pto.tile_buf + %47 = pto.treshape %40 : !pto.tile_buf -> !pto.tile_buf + %48 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %49 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf + pto.trowexpandsub ins(%44, %37 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmuls ins(%45, %cst_2 : !pto.tile_buf, f32) outs(%46 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%38 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree(%25 : !pto.pipe) {split = 1} + %c81920_i64_3 = arith.constant 81920 : i64 + %50 = pto.alloc_tile addr = %c81920_i64_3 : !pto.tile_buf + pto.tpop(%50, %25 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.tmuls ins(%50, %cst : !pto.tile_buf, f32) outs(%50 : !pto.tile_buf) + pto.trowmax ins(%50, %32 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %51 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %52 = pto.treshape %36 : !pto.tile_buf -> 
!pto.tile_buf + %53 = pto.treshape %41 : !pto.tile_buf -> !pto.tile_buf + %54 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %55 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%51, %52 : !pto.tile_buf, !pto.tile_buf) outs(%51 : !pto.tile_buf) + pto.tsub ins(%52, %51 : !pto.tile_buf, !pto.tile_buf) outs(%53 : !pto.tile_buf) + pto.tmuls ins(%51, %cst_2 : !pto.tile_buf, f32) outs(%52 : !pto.tile_buf) + pto.trowexpandsub ins(%50, %37 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%53 : !pto.tile_buf) outs(%53 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%54, %53 : !pto.tile_buf, !pto.tile_buf) outs(%54 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.tadd ins(%54, %55 : !pto.tile_buf, !pto.tile_buf) outs(%54 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree(%25 : !pto.pipe) {split = 1} + %c114688_i64_4 = arith.constant 114688 : i64 + %56 = pto.alloc_tile addr = %c114688_i64_4 : !pto.tile_buf + pto.tpop(%56, %27 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.tmov ins(%56 : !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree(%27 : !pto.pipe) {split = 1} + %c81920_i64_5 = arith.constant 81920 : i64 + %57 = pto.alloc_tile addr = %c81920_i64_5 : !pto.tile_buf + pto.tpop(%57, %25 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.tmuls ins(%57, %cst : !pto.tile_buf, f32) outs(%57 : !pto.tile_buf) + pto.trowmax ins(%57, %32 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %58 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %59 = pto.treshape %36 : !pto.tile_buf -> !pto.tile_buf + %60 = pto.treshape %40 : !pto.tile_buf -> !pto.tile_buf + %61 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %62 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%58, %59 : !pto.tile_buf, 
!pto.tile_buf) outs(%58 : !pto.tile_buf) + pto.tsub ins(%59, %58 : !pto.tile_buf, !pto.tile_buf) outs(%60 : !pto.tile_buf) + pto.tmuls ins(%58, %cst_2 : !pto.tile_buf, f32) outs(%59 : !pto.tile_buf) + pto.trowexpandsub ins(%57, %37 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%60 : !pto.tile_buf) outs(%60 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%61, %60 : !pto.tile_buf, !pto.tile_buf) outs(%61 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.tadd ins(%61, %62 : !pto.tile_buf, !pto.tile_buf) outs(%61 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree(%25 : !pto.pipe) {split = 1} + %c114688_i64_6 = arith.constant 114688 : i64 + %63 = pto.alloc_tile addr = %c114688_i64_6 : !pto.tile_buf + pto.tpop(%63, %27 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.trowexpandmul ins(%35, %41 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %63 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree(%27 : !pto.pipe) {split = 1} + %c81920_i64_7 = arith.constant 81920 : i64 + %64 = pto.alloc_tile addr = %c81920_i64_7 : !pto.tile_buf + pto.tpop(%64, %25 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.tmuls ins(%64, %cst : !pto.tile_buf, f32) outs(%64 : !pto.tile_buf) + pto.trowmax ins(%64, %32 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %65 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %66 = pto.treshape %36 : !pto.tile_buf -> !pto.tile_buf + %67 = pto.treshape %41 : !pto.tile_buf -> !pto.tile_buf + %68 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %69 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%65, %66 : !pto.tile_buf, !pto.tile_buf) outs(%65 : !pto.tile_buf) + pto.tsub ins(%66, %65 : !pto.tile_buf, !pto.tile_buf) outs(%67 : 
!pto.tile_buf) + pto.tmuls ins(%65, %cst_2 : !pto.tile_buf, f32) outs(%66 : !pto.tile_buf) + pto.trowexpandsub ins(%64, %37 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%67 : !pto.tile_buf) outs(%67 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%68, %67 : !pto.tile_buf, !pto.tile_buf) outs(%68 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.tadd ins(%68, %69 : !pto.tile_buf, !pto.tile_buf) outs(%68 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree(%25 : !pto.pipe) {split = 1} + %c7 = arith.constant 7 : index + scf.for %arg3 = %c1 to %c7 step %c1 { + %c114688_i64_10 = arith.constant 114688 : i64 + %74 = pto.alloc_tile addr = %c114688_i64_10 : !pto.tile_buf + pto.tpop(%74, %27 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.trowexpandmul ins(%35, %40 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %74 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree(%27 : !pto.pipe) {split = 1} + %c81920_i64_11 = arith.constant 81920 : i64 + %75 = pto.alloc_tile addr = %c81920_i64_11 : !pto.tile_buf + pto.tpop(%75, %25 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.tmuls ins(%75, %cst : !pto.tile_buf, f32) outs(%75 : !pto.tile_buf) + pto.trowmax ins(%75, %32 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %76 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %77 = pto.treshape %36 : !pto.tile_buf -> !pto.tile_buf + %78 = pto.treshape %40 : !pto.tile_buf -> !pto.tile_buf + %79 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %80 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%76, %77 : !pto.tile_buf, !pto.tile_buf) outs(%76 : !pto.tile_buf) + pto.tsub ins(%77, %76 : !pto.tile_buf, !pto.tile_buf) outs(%78 : !pto.tile_buf) + pto.tmuls ins(%76, %cst_2 
: !pto.tile_buf, f32) outs(%77 : !pto.tile_buf) + pto.trowexpandsub ins(%75, %37 : !pto.tile_buf, !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.texp ins(%78 : !pto.tile_buf) outs(%78 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%79, %78 : !pto.tile_buf, !pto.tile_buf) outs(%79 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.tadd ins(%79, %80 : !pto.tile_buf, !pto.tile_buf) outs(%79 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree(%25 : !pto.pipe) {split = 1} + %c114688_i64_12 = arith.constant 114688 : i64 + %81 = pto.alloc_tile addr = %c114688_i64_12 : !pto.tile_buf + pto.tpop(%81, %27 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.trowexpandmul ins(%35, %41 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %81 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree(%27 : !pto.pipe) {split = 1} + %c81920_i64_13 = arith.constant 81920 : i64 + %82 = pto.alloc_tile addr = %c81920_i64_13 : !pto.tile_buf + pto.tpop(%82, %25 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.tmuls ins(%82, %cst : !pto.tile_buf, f32) outs(%82 : !pto.tile_buf) + pto.trowmax ins(%82, %32 : !pto.tile_buf, !pto.tile_buf) outs(%37 : !pto.tile_buf) + %83 = pto.treshape %37 : !pto.tile_buf -> !pto.tile_buf + %84 = pto.treshape %36 : !pto.tile_buf -> !pto.tile_buf + %85 = pto.treshape %41 : !pto.tile_buf -> !pto.tile_buf + %86 = pto.treshape %38 : !pto.tile_buf -> !pto.tile_buf + %87 = pto.treshape %39 : !pto.tile_buf -> !pto.tile_buf + pto.tmax ins(%83, %84 : !pto.tile_buf, !pto.tile_buf) outs(%83 : !pto.tile_buf) + pto.tsub ins(%84, %83 : !pto.tile_buf, !pto.tile_buf) outs(%85 : !pto.tile_buf) + pto.tmuls ins(%83, %cst_2 : !pto.tile_buf, f32) outs(%84 : !pto.tile_buf) + pto.trowexpandsub ins(%82, %37 : !pto.tile_buf, !pto.tile_buf) 
outs(%33 : !pto.tile_buf) + pto.texp ins(%85 : !pto.tile_buf) outs(%85 : !pto.tile_buf) + pto.texp ins(%33 : !pto.tile_buf) outs(%33 : !pto.tile_buf) + pto.tmul ins(%86, %85 : !pto.tile_buf, !pto.tile_buf) outs(%86 : !pto.tile_buf) + pto.trowsum ins(%33, %32 : !pto.tile_buf, !pto.tile_buf) outs(%39 : !pto.tile_buf) + pto.tadd ins(%86, %87 : !pto.tile_buf, !pto.tile_buf) outs(%86 : !pto.tile_buf) + pto.tcvt ins(%33 {rmode = #pto} : !pto.tile_buf) outs(%34 : !pto.tile_buf) + pto.tpush_to_aic(%34 : !pto.tile_buf) {id = 30, split = 1} + pto.tfree(%25 : !pto.pipe) {split = 1} + } + %c114688_i64_8 = arith.constant 114688 : i64 + %70 = pto.alloc_tile addr = %c114688_i64_8 : !pto.tile_buf + pto.tpop(%70, %27 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.trowexpandmul ins(%35, %40 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %70 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree(%27 : !pto.pipe) {split = 1} + %c114688_i64_9 = arith.constant 114688 : i64 + %71 = pto.alloc_tile addr = %c114688_i64_9 : !pto.tile_buf + pto.tpop(%71, %27 : !pto.tile_buf, !pto.pipe) {split = 1} + pto.trowexpandmul ins(%35, %41 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tadd ins(%35, %71 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + pto.tfree(%27 : !pto.pipe) {split = 1} + pto.trowexpanddiv ins(%35, %38 : !pto.tile_buf, !pto.tile_buf) outs(%35 : !pto.tile_buf) + %72 = arith.addi %43, %31 : index + %73 = pto.partition_view %42, offsets = [%72, %c0], sizes = [%c16, %c128] : !pto.tensor_view + pto.tstore ins(%35 : !pto.tile_buf) outs(%73 : !pto.partition_tensor_view<16x128xf32>) + } + return + } + func.func @call_both(%arg0: memref<256xi64>, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: !pto.ptr, %arg4: !pto.ptr, %arg5: !pto.ptr) attributes {pto.entry} { + pto.set_ffts %arg0 : memref<256xi64> + call @cube_kernel(%arg1, %arg2, %arg3, %arg4) : (!pto.ptr, !pto.ptr, !pto.ptr, !pto.ptr) -> () + call @vector_kernel(%arg1, 
%arg5) : (!pto.ptr, !pto.ptr) -> () + return + } +} + diff --git a/examples/aot/flash_attention/split_pipe/codegen_results/fa.ptoas.generated.cpp b/examples/aot/flash_attention/split_pipe/codegen_results/fa.ptoas.generated.cpp new file mode 100644 index 00000000..87bb5aae --- /dev/null +++ b/examples/aot/flash_attention/split_pipe/codegen_results/fa.ptoas.generated.cpp @@ -0,0 +1,650 @@ +#include "pto/pto-inst.hpp" +using namespace pto; + +enum class PTOAutoSyncTailMode : int { + kBarrierAll = 0, + kSetWaitMte3ToSEvent0 = 1, +}; + +static AICORE inline void ptoas_auto_sync_tail( + PTOAutoSyncTailMode mode = PTOAutoSyncTailMode::kBarrierAll) { + switch (mode) { + case PTOAutoSyncTailMode::kSetWaitMte3ToSEvent0: + set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + break; + case PTOAutoSyncTailMode::kBarrierAll: + default: + pipe_barrier(PIPE_ALL); + break; + } +} + +AICORE void cube_kernel(__gm__ float* v1, __gm__ half* v2, __gm__ half* v3, __gm__ half* v4) { + unsigned v5 = 7680; + unsigned v6 = 512; + unsigned v7 = 0; + const int32_t v8 = 96; + const int32_t v9 = 512; + const int32_t v10 = 128; + const int32_t v11 = 32; + const int32_t v12 = 2; + const int32_t v13 = 1; + const int32_t v14 = 229376; + const int32_t v15 = 131072; + const int32_t v16 = 163840; + const int64_t v17 = 0; + const int64_t v18 = 8192; + const int64_t v19 = 172032; + const int64_t v20 = 65536; + const int32_t v21 = 7; + const int32_t v22 = 0; + const int32_t v23 = 65536; + const int32_t v24 = 327680; + using T = float; + + #if defined(__DAV_CUBE__) + size_t v25 = (size_t) v13; + int64_t v26 = get_block_num(); + int32_t v27 = (int32_t) ((int64_t) v26); + int64_t v28 = get_block_idx(); + int32_t v29 = (int32_t) ((int64_t) v28); + int32_t v30 = v8 / v27; + int32_t v31 = v8 % v27; + int32_t v32 = (int32_t) ((uint32_t) v30 + (uint32_t) v13); + bool v33 = v29 < v31; + int32_t v34 = v33 ? 
(int32_t) ((uint32_t) v29 * (uint32_t) v32) : (int32_t) ((uint32_t) ((int32_t) (uint32_t) v31 * (uint32_t) v32) + (uint32_t) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v29 - (uint32_t) v31) * (uint32_t) v30)); + int32_t v35 = (int32_t) ((uint32_t) v29 * (uint32_t) v14); + __gm__ float* v36 = v1 + v35; + + int32_t v37 = (int32_t) ((uint32_t) v35 + (uint32_t) v15); + __gm__ float* v38 = v1 + v37; + + int32_t v39 = (int32_t) ((uint32_t) v35 + (uint32_t) v16); + __gm__ float* v40 = v1 + v39; + + auto v41 = TPipe<0, Direction::DIR_C2V, 65536, 8, 1, false>(v36, v22, v22); + auto v42 = TPipe<2, Direction::DIR_C2V, 16384, 8, 1, false>(v38, v23, v22); + auto v43 = TPipe<4, Direction::DIR_V2C, 32768, 8, 8, false>(v40, v22, v24); + Tile v44; + TASSIGN(v44, v17); + Tile v45; + TASSIGN(v45, v17); + Tile v46; + TASSIGN(v46, v18); + Tile v47; + TASSIGN(v47, v17); + Tile v48; + TASSIGN(v48, v17); + Tile v49; + TASSIGN(v49, v18); + Tile v50; + TASSIGN(v50, v19); + Tile v51; + TASSIGN(v51, v17); + Tile v52; + TASSIGN(v52, v20); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + for (size_t v53 = (size_t) v34; v53 < ((size_t) ((int32_t) (uint32_t) v34 + (uint32_t) (v33 ? 
v32 : v30))); v53 += v25) { + pto::Shape<1, 1, 1, 32, 128> v54 = pto::Shape<1, 1, 1, 32, 128>(); + pto::Stride<4096, 4096, 4096, 128, 1> v55 = pto::Stride<4096, 4096, 4096, 128, 1>(); + GlobalTensor, pto::Stride<4096, 4096, 4096, 128, 1>, pto::Layout::ND> v56 = GlobalTensor, pto::Stride<4096, 4096, 4096, 128, 1>, pto::Layout::ND>(v2 + (v7 + (unsigned) ((int32_t) (uint32_t) ((int32_t) v53) * (uint32_t) v11) * (unsigned) v10 + v7 * (unsigned) v13), v54, v55); + TLOAD(v44, v56); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(v45, v44); + pto::Shape<1, 1, 1, 128, 512> v57 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<128, 128, 128, 1, 128> v58 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v59 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v7 + v7 * (unsigned) v13 + v7 * (unsigned) v10), v57, v58); + TLOAD(v46, v59); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + TMOV(v47, v46); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + TMATMUL(v48, v45, v47); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v41, v48); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + pto::Shape<1, 1, 1, 128, 512> v60 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<128, 128, 128, 1, 128> v61 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v62 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v7 + v7 * (unsigned) v13 + v6 * (unsigned) v10), v60, v61); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + TLOAD(v46, v62); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_MTE2, 
PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v47, v46); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID1); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + TMATMUL(v48, v45, v47); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID1); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v41, v48); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + pto::Shape<1, 1, 1, 512, 128> v63 = pto::Shape<1, 1, 1, 512, 128>(); + pto::Stride<65536, 65536, 65536, 128, 1> v64 = pto::Stride<65536, 65536, 65536, 128, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 128, 1>, pto::Layout::ND> v65 = GlobalTensor, pto::Stride<65536, 65536, 65536, 128, 1>, pto::Layout::ND>(v4 + (v7 + v7 * (unsigned) v10 + v7 * (unsigned) v13), v63, v64); + TLOAD(v50, v65); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + for (size_t v66 = (size_t) v22; v66 < ((size_t) v21); v66 += v25) { + int32_t v67 = (int32_t) ((uint32_t) ((int32_t) v66) * (uint32_t) v12); + pto::Shape<1, 1, 1, 128, 512> v68 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<128, 128, 128, 1, 128> v69 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v70 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v7 + v7 * (unsigned) v13 + (unsigned) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v67 + (uint32_t) v12) * (uint32_t) v9) * (unsigned) v10), v68, v69); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + TLOAD(v46, v70); + Tile v71; + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v43, v71); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID3); + TMOV(v49, v71); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v43); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + TMOV(v51, v50); + set_flag(PIPE_MTE1, 
PIPE_MTE2, EVENT_ID4); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + int32_t v72 = (int32_t) ((uint32_t) v67 + (uint32_t) v13); + pto::Shape<1, 1, 1, 512, 128> v73 = pto::Shape<1, 1, 1, 512, 128>(); + pto::Stride<65536, 65536, 65536, 128, 1> v74 = pto::Stride<65536, 65536, 65536, 128, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 128, 1>, pto::Layout::ND> v75 = GlobalTensor, pto::Stride<65536, 65536, 65536, 128, 1>, pto::Layout::ND>(v4 + (v7 + (unsigned) ((int32_t) (uint32_t) v72 * (uint32_t) v9) * (unsigned) v10 + v7 * (unsigned) v13), v73, v74); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID4); + TLOAD(v50, v75); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + TMATMUL(v52, v49, v51); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID2); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v42, v52); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID4); + TMOV(v47, v46); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID3); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + TMATMUL(v48, v45, v47); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID3); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v41, v48); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + pto::Shape<1, 1, 1, 128, 512> v76 = pto::Shape<1, 1, 1, 128, 512>(); + pto::Stride<128, 128, 128, 1, 128> v77 = pto::Stride<128, 128, 128, 1, 128>(); + GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN> v78 = GlobalTensor, pto::Stride<128, 128, 128, 1, 128>, pto::Layout::DN>(v3 + (v7 + v7 * (unsigned) v13 + (unsigned) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v72 + (uint32_t) v12) * (uint32_t) v9) * (unsigned) v10), v76, v77); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID5); + TLOAD(v46, v78); + Tile v79; + TPOP, Tile, 
TileSplitAxis::TILE_UP_DOWN>(v43, v79); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID4); + TMOV(v49, v79); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v43); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID5); + TMOV(v51, v50); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + pto::Shape<1, 1, 1, 512, 128> v80 = pto::Shape<1, 1, 1, 512, 128>(); + pto::Stride<65536, 65536, 65536, 128, 1> v81 = pto::Stride<65536, 65536, 65536, 128, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 128, 1>, pto::Layout::ND> v82 = GlobalTensor, pto::Stride<65536, 65536, 65536, 128, 1>, pto::Layout::ND>(v4 + (v7 + (unsigned) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v72 + (uint32_t) v13) * (uint32_t) v9) * (unsigned) v10 + v7 * (unsigned) v13), v80, v81); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID6); + TLOAD(v50, v82); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID4); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID4); + TMATMUL(v52, v49, v51); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID4); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v42, v52); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID6); + TMOV(v47, v46); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID5); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID5); + TMATMUL(v48, v45, v47); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID5); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v41, v48); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + }; + set_flag(PIPE_FIX, PIPE_M, EVENT_ID6); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + Tile v83; + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v43, v83); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID5); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID5); + TMOV(v49, v83); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v43); 
+ wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID7); + TMOV(v51, v50); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID7); + pto::Shape<1, 1, 1, 512, 128> v84 = pto::Shape<1, 1, 1, 512, 128>(); + pto::Stride<65536, 65536, 65536, 128, 1> v85 = pto::Stride<65536, 65536, 65536, 128, 1>(); + GlobalTensor, pto::Stride<65536, 65536, 65536, 128, 1>, pto::Layout::ND> v86 = GlobalTensor, pto::Stride<65536, 65536, 65536, 128, 1>, pto::Layout::ND>(v4 + (v7 + v5 * (unsigned) v10 + v7 * (unsigned) v13), v84, v85); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID7); + TLOAD(v50, v86); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID6); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID6); + TMATMUL(v52, v49, v51); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID6); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID6); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v42, v52); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID7); + Tile v87; + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v43, v87); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID6); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + TMOV(v49, v87); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v43); + TMOV(v51, v50); + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID7); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID7); + TMATMUL(v52, v49, v51); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_FIX, EVENT_ID7); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID7); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v42, v52); + } + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID3); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID2); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID3); + #endif // __DAV_CUBE__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +AICORE void vector_kernel(__gm__ float* v1, __gm__ float* v2) { + unsigned v3 = 0; + 
RoundMode v4 = RoundMode::CAST_RINT; + const int32_t v5 = 96; + const int32_t v6 = 128; + const int32_t v7 = 16; + const int32_t v8 = 32; + const int32_t v9 = 1; + const int32_t v10 = 229376; + const int32_t v11 = 131072; + const int32_t v12 = 163840; + const int64_t v13 = 114688; + const int64_t v14 = 147456; + const int64_t v15 = 163840; + const int64_t v16 = 180224; + const int64_t v17 = 188416; + const int64_t v18 = 188480; + const int64_t v19 = 188544; + const int64_t v20 = 188608; + const int64_t v21 = 188672; + const int64_t v22 = 188736; + const float v23 = 0.0883883461f; + const float v24 = 1.0f; + const int64_t v25 = 81920; + const int32_t v26 = 7; + const int32_t v27 = 0; + const int32_t v28 = 65536; + const int32_t v29 = 327680; + using T = float; + + #if defined(__DAV_VEC__) + set_mask_norm(); + set_vector_mask(-1, -1); + size_t v30 = (size_t) v9; + int64_t v31 = get_block_num(); + int32_t v32 = (int32_t) ((int64_t) v31); + int64_t v33 = get_block_idx(); + int32_t v34 = (int32_t) ((int64_t) v33); + int32_t v35 = v5 / v32; + int32_t v36 = v5 % v32; + int32_t v37 = (int32_t) ((uint32_t) v35 + (uint32_t) v9); + bool v38 = v34 < v36; + int32_t v39 = v38 ? 
(int32_t) ((uint32_t) v34 * (uint32_t) v37) : (int32_t) ((uint32_t) ((int32_t) (uint32_t) v36 * (uint32_t) v37) + (uint32_t) ((int32_t) (uint32_t) ((int32_t) (uint32_t) v34 - (uint32_t) v36) * (uint32_t) v35)); + int32_t v40 = (int32_t) ((uint32_t) v34 * (uint32_t) v10); + __gm__ float* v41 = v1 + v40; + + int32_t v42 = (int32_t) ((uint32_t) v40 + (uint32_t) v11); + __gm__ float* v43 = v1 + v42; + + int32_t v44 = (int32_t) ((uint32_t) v40 + (uint32_t) v12); + __gm__ float* v45 = v1 + v44; + + auto v46 = TPipe<0, Direction::DIR_C2V, 65536, 8, 1, false>(v41, v27, v27); + auto v47 = TPipe<2, Direction::DIR_C2V, 16384, 8, 1, false>(v43, v28, v27); + auto v48 = TPipe<4, Direction::DIR_V2C, 32768, 8, 8, false>(v45, v27, v29); + int64_t v49 = get_subblockid(); + Tile v50; + TASSIGN(v50, v13); + Tile v51; + TASSIGN(v51, v14); + Tile v52; + TASSIGN(v52, v15); + Tile v53; + TASSIGN(v53, v16); + Tile v54; + TASSIGN(v54, v17); + Tile v55; + TASSIGN(v55, v18); + Tile v56; + TASSIGN(v56, v19); + Tile v57; + TASSIGN(v57, v20); + Tile v58; + TASSIGN(v58, v21); + Tile v59; + TASSIGN(v59, v22); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID5); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID7); + for (size_t v60 = (size_t) v39; v60 < ((size_t) ((int32_t) (uint32_t) v39 + (uint32_t) (v38 ? 
v37 : v35))); v60 += v30) { + Tile v61; + TASSIGN(v61, v25); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v46, v61); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v61, v61, v23); + pipe_barrier(PIPE_V); + TROWMAX(v55, v61, v50); + Tile v62; + TRESHAPE(v62, v55); + Tile v63; + TRESHAPE(v63, v54); + pipe_barrier(PIPE_V); + TROWEXPANDSUB(v51, v61, v55); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TMULS(v63, v62, v24); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + pipe_barrier(PIPE_V); + TROWSUM(v56, v51, v50); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + TCVT(v52, v51, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v48, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v46); + Tile v64; + TASSIGN(v64, v25); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v46, v64); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + TMULS(v64, v64, v23); + pipe_barrier(PIPE_V); + TROWMAX(v55, v64, v50); + Tile v65; + TRESHAPE(v65, v59); + Tile v66; + TRESHAPE(v66, v56); + Tile v67; + TRESHAPE(v67, v57); + pipe_barrier(PIPE_V); + TMAX(v62, v62, v63); + pipe_barrier(PIPE_V); + TSUB(v65, v63, v62); + pipe_barrier(PIPE_V); + TMULS(v63, v62, v24); + TROWEXPANDSUB(v51, v64, v55); + TEXP(v65, v65); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + TMUL(v66, v66, v65); + pipe_barrier(PIPE_V); + TROWSUM(v57, v51, v50); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + pipe_barrier(PIPE_V); + TADD(v66, v66, v67); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TCVT(v52, v51, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID1); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v48, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v46); + Tile v68; + TASSIGN(v68, v13); + wait_flag(PIPE_V, PIPE_MTE2, 
EVENT_ID1); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v68); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + TMOV(v53, v68); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v47); + Tile v69; + TASSIGN(v69, v25); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v46, v69); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + TMULS(v69, v69, v23); + pipe_barrier(PIPE_V); + TROWMAX(v55, v69, v50); + Tile v70; + TRESHAPE(v70, v58); + pipe_barrier(PIPE_V); + TMAX(v62, v62, v63); + pipe_barrier(PIPE_V); + TSUB(v70, v63, v62); + pipe_barrier(PIPE_V); + TMULS(v63, v62, v24); + TROWEXPANDSUB(v51, v69, v55); + TEXP(v70, v70); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + TMUL(v66, v66, v70); + pipe_barrier(PIPE_V); + TROWSUM(v57, v51, v50); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + pipe_barrier(PIPE_V); + TADD(v66, v66, v67); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID3); + TCVT(v52, v51, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID2); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID2); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v48, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID4); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v46); + Tile v71; + TASSIGN(v71, v13); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v71); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + TROWEXPANDMUL(v53, v53, v59); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID4); + TADD(v53, v53, v71); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v47); + Tile v72; + TASSIGN(v72, v25); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v46, v72); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID5); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID5); + TMULS(v72, v72, v23); + pipe_barrier(PIPE_V); + TROWMAX(v55, v72, v50); + pipe_barrier(PIPE_V); + TMAX(v62, v62, v63); + pipe_barrier(PIPE_V); + TSUB(v65, v63, v62); + pipe_barrier(PIPE_V); + TMULS(v63, v62, v24); + TROWEXPANDSUB(v51, v72, v55); + TEXP(v65, v65); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + 
TMUL(v66, v66, v65); + pipe_barrier(PIPE_V); + TROWSUM(v57, v51, v50); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + pipe_barrier(PIPE_V); + TADD(v66, v66, v67); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID4); + TCVT(v52, v51, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v48, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID5); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v46); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID5); + for (size_t v73 = v30; v73 < ((size_t) v26); v73 += v30) { + Tile v74; + TASSIGN(v74, v13); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v74); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID6); + TROWEXPANDMUL(v53, v53, v58); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID6); + TADD(v53, v53, v74); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v47); + Tile v75; + TASSIGN(v75, v25); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v46, v75); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v75, v75, v23); + pipe_barrier(PIPE_V); + TROWMAX(v55, v75, v50); + pipe_barrier(PIPE_V); + TMAX(v62, v62, v63); + pipe_barrier(PIPE_V); + TSUB(v70, v63, v62); + pipe_barrier(PIPE_V); + TMULS(v63, v62, v24); + TROWEXPANDSUB(v51, v75, v55); + TEXP(v70, v70); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + TMUL(v66, v66, v70); + pipe_barrier(PIPE_V); + TROWSUM(v57, v51, v50); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID6); + pipe_barrier(PIPE_V); + TADD(v66, v66, v67); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + TCVT(v52, v51, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID4); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID4); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v48, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v46); + Tile v76; + TASSIGN(v76, v13); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID6); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v76); 
+ set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TROWEXPANDMUL(v53, v53, v59); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v53, v53, v76); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v47); + Tile v77; + TASSIGN(v77, v25); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v46, v77); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TMULS(v77, v77, v23); + pipe_barrier(PIPE_V); + TROWMAX(v55, v77, v50); + pipe_barrier(PIPE_V); + TMAX(v62, v62, v63); + pipe_barrier(PIPE_V); + TSUB(v65, v63, v62); + pipe_barrier(PIPE_V); + TMULS(v63, v62, v24); + TROWEXPANDSUB(v51, v77, v55); + TEXP(v65, v65); + pipe_barrier(PIPE_V); + TEXP(v51, v51); + TMUL(v66, v66, v65); + pipe_barrier(PIPE_V); + TROWSUM(v57, v51, v50); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + pipe_barrier(PIPE_V); + TADD(v66, v66, v67); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID2); + TCVT(v52, v51, v4); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID5); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID5); + TPUSH, Tile, TileSplitAxis::TILE_UP_DOWN>(v48, v52); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v46); + }; + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + Tile v78; + TASSIGN(v78, v13); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID7); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v78); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TROWEXPANDMUL(v53, v53, v58); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v53, v53, v78); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v47); + Tile v79; + TASSIGN(v79, v13); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + TPOP, Tile, TileSplitAxis::TILE_UP_DOWN>(v47, v79); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_V); + TROWEXPANDMUL(v53, v53, v59); + pipe_barrier(PIPE_V); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + TADD(v53, v53, v79); + TFREE, TileSplitAxis::TILE_UP_DOWN>(v47); + pipe_barrier(PIPE_V); + 
TROWEXPANDDIV(v53, v53, v56); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID6); + pto::Shape<1, 1, 1, 16, 128> v80 = pto::Shape<1, 1, 1, 16, 128>(); + pto::Stride<2048, 2048, 2048, 128, 1> v81 = pto::Stride<2048, 2048, 2048, 128, 1>(); + GlobalTensor, pto::Stride<2048, 2048, 2048, 128, 1>, pto::Layout::ND> v82 = GlobalTensor, pto::Stride<2048, 2048, 2048, 128, 1>, pto::Layout::ND>(v2 + (v3 + (unsigned) ((int32_t) (uint32_t) ((int32_t) (uint32_t) ((int32_t) v60) * (uint32_t) v8) + (uint32_t) ((int32_t) (uint32_t) ((int32_t) (int64_t) v49) * (uint32_t) v7)) * (unsigned) v6 + v3 * (unsigned) v9), v80, v81); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID6); + TSTORE(v82, v53); + } + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID4); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID5); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID6); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID7); + #endif // __DAV_VEC__ + + ptoas_auto_sync_tail(PTOAutoSyncTailMode::kBarrierAll); + return; +} + +__global__ AICORE void call_both(__gm__ int64_t* v1, __gm__ float* v2, __gm__ half* v3, __gm__ half* v4, __gm__ half* v5, __gm__ float* v6) { + using T = float; + uint64_t v7 = (uint64_t) v1; + set_ffts_base_addr(v7); + cube_kernel(v2, v3, v4, v5); + vector_kernel(v2, v6); + return; +} diff --git a/examples/aot/flash_attention/split_pipe/codegen_results/fa_performance_kernel.reference.cpp b/examples/aot/flash_attention/split_pipe/codegen_results/fa_performance_kernel.reference.cpp new file mode 100644 index 00000000..ef96b38e --- /dev/null +++ b/examples/aot/flash_attention/split_pipe/codegen_results/fa_performance_kernel.reference.cpp @@ -0,0 +1,998 @@ +/** +Copyright (c) 2026 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under the terms and conditions of +CANN Open Software License Agreement Version 2.0 (the "License"). +Please refer to the License for details. 
You may not use this file except in compliance with the License. +THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +See LICENSE in the root of the software repository for the full text of the License. +*/ + +#include +#include + +#include "fa_performance_kernel.h" +#include +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) +#include +#elif defined(__DAV_C310_CUBE__) || defined(__DAV_C310_VEC__) +#include +#endif +#include "pto_macro_matmul.hpp" +#include "pto_macro_fa_softmax.hpp" +#include "pto_macro_fa_gu.hpp" + +#define UF_ENABLE 1 + +using namespace std; +using namespace pto; + +#ifndef FFTS_BUFFER_FLAG_ENUM +#define FFTS_BUFFER_FLAG_ENUM +// Buffer flag values for FFTS pipeline coordination +enum FftsBufferFlag : uint32_t +{ + BUF0_QK_READY = 0, // Buffer 0: QK data ready + BUF0_SM_CONSUMED = 1, // Buffer 0: Softmax consumed + BUF1_SM_READY = 2, // Buffer 1: Softmax output ready + BUF1_SV_CONSUMED = 3, // Buffer 1: SV consumed + UPDATE_READY = 4, // Update stage ready + UPDATE_CONSUMED = 5, // Update stage consumed + CV_BLOCK_END = 7, // CV comm slot block end (CV_COMM_CTRL reserved in TSyncCVID) +}; +#endif + +enum CoreEvtID : uint32_t +{ + QK_EVENT_ID0, + QK_EVENT_ID1, + PV_EVENT_ID0, + PV_EVENT_ID1, +}; + +// ----------------------------------------------------------------------------- +// Performance tuning knobs (high-level) +// +// The kernel is a cross-core pipeline (Cube + Vec) with explicit FIFOs: +// QK (Cube): compute_qk -> qk_tile_fifo (fp32) +// P (Vec): compute_p -> p_tile_fifo (fp16 x_exp) + l1_exp_max_ififo +// PV (Cube): compute_pv -> pv_tile_fifo (fp32) +// GU (Vec): compute_gu -> o_out (fp32) with running rescale/update +// +// Key knobs that impact throughput (see runTFA<> below): +// - CUBE_S0 / CUBE_S1: tile sizes for QK/PV cube matmuls (compute intensity vs. 
buffer pressure) +// - qkPreloadNum: pipeline warmup depth (more overlap vs. more L1 FIFO footprint) +// - *_TNBuffers: ping/pong depth for Mat tiles (overlap) and Vec tiles (latency hiding) +// - QKV_CV_FIFO / PV_CV_FIFO: FIFO depth between stages (avoid backpressure) +// ----------------------------------------------------------------------------- + +// Inline macro used for small, performance-sensitive functions +#ifndef PTO_INLINE +#define PTO_INLINE __attribute__((always_inline)) inline +#endif + +// Detect build-time macros and expose as constexpr flags for clearer conditionals +#ifdef __DAV_CUBE__ +constexpr bool DAV_CUBE = true; +#else +constexpr bool DAV_CUBE = false; +#endif + +#ifdef __DAV_VEC__ +constexpr bool DAV_VEC = true; +#else +constexpr bool DAV_VEC = false; +#endif + +constexpr std::size_t MAX_TILE_L1_BYTES = 512U * 1024U; +constexpr std::size_t MAX_VEC_UB_BYTES = 192U * 1024U; + +template +constexpr AICORE std::size_t tile_storage_bytes() +{ + using ElementType = typename TileType::DType; + return static_cast(TileType::Rows * TileType::Cols) * sizeof(ElementType); +} + +template +constexpr AICORE std::size_t tile_buffer_total_bytes() +{ + return tile_storage_bytes() * NumBuffers; +} + +template +AICORE inline uint32_t assign_tile_buffers(TileType (&tiles)[NumBuffers], uint32_t base_offset) +{ + if constexpr (NumBuffers == 0) { + return base_offset; + } + + constexpr std::size_t total_storage_bytes = tile_buffer_total_bytes(); + static_assert(total_storage_bytes <= MAX_TILE_L1_BYTES, "Tile buffer L1 allocation exceeds 512KB"); + + for (std::size_t idx = 0; idx < NumBuffers; ++idx) { + const uint32_t tile_offset = base_offset + static_cast(idx * tile_storage_bytes()); + TASSIGN(tiles[idx], tile_offset); + } + + return base_offset + static_cast(total_storage_bytes); +} + +template +AICORE inline uint32_t assign_tile_buffers_union(TileA (&tilesA)[NumA], TileB (&tilesB)[NumB], uint32_t base_offset) +{ + static_assert(NumA == NumB, "Union assignment 
expects matching buffer counts"); + if constexpr (NumA == 0) { + return base_offset; + } + + constexpr std::size_t stride_bytes = (tile_storage_bytes() > tile_storage_bytes()) ? + tile_storage_bytes() : + tile_storage_bytes(); + constexpr std::size_t total_storage_bytes = stride_bytes * NumA; + static_assert(total_storage_bytes <= MAX_VEC_UB_BYTES, "Union tile UB allocation exceeds 192KB"); + + for (std::size_t idx = 0; idx < NumA; ++idx) { + const uint32_t tile_offset = base_offset + static_cast(idx * stride_bytes); + TASSIGN(tilesA[idx], tile_offset); + TASSIGN(tilesB[idx], tile_offset); + } + + return base_offset + static_cast(total_storage_bytes); +} + +template +AICORE inline void allocate_cube_tile_buffers(TileQType (&qTiles)[NumQ], TileKType (&kTiles)[NumK], + TilePType (&pTiles)[NumP], TileVType (&vTiles)[NumV]) +{ + constexpr std::size_t total_bytes = + tile_buffer_total_bytes() + tile_buffer_total_bytes() + + tile_buffer_total_bytes() + tile_buffer_total_bytes(); + static_assert(total_bytes <= MAX_TILE_L1_BYTES, "Total cube L1 allocation exceeds 512KB"); + + uint32_t l1_offset = 0; + l1_offset = assign_tile_buffers(qTiles, l1_offset); + l1_offset = assign_tile_buffers(kTiles, l1_offset); + l1_offset = assign_tile_buffers(pTiles, l1_offset); + l1_offset = assign_tile_buffers(vTiles, l1_offset); + (void)l1_offset; +} + +template +AICORE inline void allocate_vec_tile_buffers(TileDataF_T (&srcTiles)[SrcBuffers], ReduceTileF_T &m1_local_max, + TileDataF_T &input_reduce_tmp, ReduceTileF_T &l1_local_sum, + ReduceTileF_T &m2_global_max, ReduceTileF_T &l2_global_sum, + ReduceTileF_T (&l1_exp_max)[ExpMaxBuffers], + TileDataH_T (&x_expT)[XexpBuffers], TileOutT (&pvTile)[pvVecBuffers], + TileOutT &runningOTile, TileDataF_T &triu) +{ + constexpr std::size_t float_tile_bytes = tile_storage_bytes(); + constexpr std::size_t reduce_tile_bytes = tile_storage_bytes(); + constexpr std::size_t xexp_bytes = tile_buffer_total_bytes(); + constexpr std::size_t out_tile_bytes = 
tile_storage_bytes(); + constexpr std::size_t union_stride = (tile_storage_bytes() > tile_storage_bytes()) ? + tile_storage_bytes() : + tile_storage_bytes(); + static_assert(SrcBuffers == pvVecBuffers, "src/pv ping-pong buffer counts must match for union allocation"); + constexpr std::size_t union_bytes = union_stride * SrcBuffers; + constexpr std::size_t total_bytes = union_bytes + xexp_bytes + (reduce_tile_bytes * (3U + ExpMaxBuffers)) + + (float_tile_bytes / 8 * 1U) + (float_tile_bytes * 1U) + out_tile_bytes; + static_assert(total_bytes <= MAX_VEC_UB_BYTES, "Vec tile UB allocation exceeds 192KB"); + + uint32_t offset = 0; + TASSIGN(runningOTile, offset); + offset += out_tile_bytes; + offset = assign_tile_buffers_union(srcTiles, pvTile, offset); + + TASSIGN(m1_local_max, offset); + offset += static_cast(reduce_tile_bytes); + + TASSIGN(m2_global_max, offset); + offset += static_cast(reduce_tile_bytes); + + uint32_t tmp_float_offset = offset; + TASSIGN(input_reduce_tmp, tmp_float_offset); + offset += static_cast(float_tile_bytes) / 8; + + TASSIGN(triu, offset); + offset += static_cast(float_tile_bytes); + + TASSIGN(l1_local_sum, offset); + offset += static_cast(reduce_tile_bytes); + + TASSIGN(l2_global_sum, offset); + offset += static_cast(reduce_tile_bytes); + + offset = assign_tile_buffers(l1_exp_max, offset); + + uint32_t tail_offset = assign_tile_buffers(x_expT, offset); + + (void)tail_offset; +} + +// Helper to assign an accumulator tile to one of two ping-pong UB addresses (0x0 / 0x10000). +// Keeps a per-type static running index that toggles on every call. Caller may pass +// `initial_id` (0 or 1) to set the starting buffer index on the first call for that tile type. 
+template +AICORE inline int assign_running_acc_tile(AccTileT &accTile, int initial_id = -1) +{ + static int running_tile_buffer_idx = 0; // per-instantiation running buffer index: 0 -> base0, 1 -> base1 + if (initial_id == 0 || initial_id == 1) { + running_tile_buffer_idx = initial_id; + } + const int id = running_tile_buffer_idx; + const uint32_t base_addr = (id == 0) ? 0x0u : 0x10000u; + TASSIGN(accTile, base_addr); + running_tile_buffer_idx ^= 1; // toggle for next call + return id; +} + +template +AICORE inline void compute_qk(QKPipe &qkPipe, int tile_id, int sub_tile_id, __gm__ half *q, __gm__ half *k, + TileMatQData &qMatTile, TileMatKData &kMatTile, TileQKData &qkAccTile, + QKSlotGlobal &qkSlotGlobal, uint64_t qkMatTileEventId, int accTileEvtID, int blk_idx) +{ + if constexpr (DAV_CUBE) { + constexpr uint32_t Cube_S0 = CUBE_S0; + constexpr uint32_t Cube_S1 = CUBE_S1; + constexpr uint32_t Tile_S1 = TILE_S1; + constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1; + constexpr uint32_t Cube_HEAD = HEAD_SIZE; + + constexpr int QKP_CV_FIFO = QKPipe::RingFiFo::SLOT_NUM; + static_assert(QKP_CV_FIFO >= 1, "QKP_CV_FIFO must be >= 1"); + static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1"); + + const int s0_index = blk_idx * CUBE_S0; + const int s1_index = tile_id * static_cast(Tile_S1) + sub_tile_id * static_cast(Cube_S1); + if (sub_tile_id == 0) { + TALLOC(qkPipe, qkSlotGlobal); + } + if constexpr (CAUSAL_MASK) { + if (s1_index > s0_index) { + if (sub_tile_id == static_cast(kTileFactor) - 1) { + TPUSH(qkPipe, qkSlotGlobal); + } + return; + } + } + using GlobalDataQ = + GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + using GlobalDataK = GlobalTensor, + pto::Stride<1, 1, 1, 1, HEAD_SIZE>, Layout::DN>; // BNSD - (N, K) layout + + GlobalDataQ qGlobal(q); + GlobalDataK kGlobal(k + s1_index * HEAD_SIZE); + + wait_flag(PIPE_MTE1, PIPE_MTE2, qkMatTileEventId); + + if (tile_id == 0 && sub_tile_id == 0) { + TLOAD(qMatTile, qGlobal); + } + + 
TLOAD(kMatTile, kGlobal); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + +#if UF_ENABLE + pto_macro_matmul(qMatTile, kMatTile, qkAccTile, AccMode::InitFinalSum); +#else + wait_flag(PIPE_FIX, PIPE_M, accTileEvtID); + pto_macro_matmul(qMatTile, kMatTile, qkAccTile, AccMode::Init); +#endif + + set_flag(PIPE_MTE1, PIPE_MTE2, qkMatTileEventId); +#if !UF_ENABLE + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); +#endif + + using QKStoreGlobal = + GlobalTensor, pto::Stride<1, 1, 1, Tile_S1, 1>>; + QKStoreGlobal qkStoreGlobal(qkSlotGlobal.data() + static_cast(sub_tile_id) * Cube_S1); +#if UF_ENABLE + TSTORE(qkStoreGlobal, qkAccTile); +#else + TSTORE(qkStoreGlobal, qkAccTile); +#endif + + if (sub_tile_id == static_cast(kTileFactor) - 1) { + TPUSH(qkPipe, qkSlotGlobal); + } + +#if !UF_ENABLE + set_flag(PIPE_FIX, PIPE_M, accTileEvtID); +#endif + } +} + +template +AICORE inline void compute_pv(PPipe &pPipe, PVPipe &pvPipe, int tile_id, int sub_tile_id, __gm__ half *v, + TileMatPData &pMatTile, TileMatVData &vMatTile, TilePVData &pvAccTile, + PSlotGlobal &pSlotGlobal, PVSlotGlobal &pvSlotGlobal, uint64_t svMatTileEventId, + int accTileEvtID, int blk_idx) +{ + constexpr uint32_t Cube_S0 = CUBE_S0; + constexpr uint32_t Cube_S1 = CUBE_S1; + constexpr uint32_t Tile_S1 = TILE_S1; + constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1; + constexpr uint32_t Cube_HEAD = HEAD_SIZE; + constexpr uint32_t TileElems = Cube_S0 * Tile_S1; + constexpr int QKP_CV_FIFO = PVPipe::RingFiFo::SLOT_NUM; + static_assert(QKP_CV_FIFO >= 1, "PV_CV_FIFO must be >= 1"); + static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1"); + + const int s0_index = blk_idx * Cube_S0; + const int s1_index = tile_id * static_cast(Tile_S1) + sub_tile_id * static_cast(Cube_S1); + const bool is_last_subtile = (sub_tile_id + 1 == static_cast(kTileFactor)); + const bool next_will_be_skipped = (s1_index + static_cast(Cube_S1)) > 
s0_index && CAUSAL_MASK; + + if constexpr (DAV_CUBE) { + if (sub_tile_id == 0) { + TPOP(pPipe, pSlotGlobal); + } + if constexpr (CAUSAL_MASK) { + if (s1_index > s0_index) { + if (is_last_subtile) { + TFREE(pPipe, pSlotGlobal); + } + return; + } + } + + using GlobalVT = + GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + + wait_flag(PIPE_MTE1, PIPE_MTE2, svMatTileEventId); + + GlobalVT vLoad((__gm__ half *)(v + s1_index * HEAD_SIZE)); + TLOAD(vMatTile, vLoad); + + using PLoadGlobal = + GlobalTensor, pto::Stride<1, 1, 1, Tile_S1, 1>>; + PLoadGlobal pLoadGlobal(pSlotGlobal.data() + static_cast(sub_tile_id) * Cube_S1); + TLOAD(pMatTile, pLoadGlobal); + if (is_last_subtile) { + TFREE(pPipe, pSlotGlobal); + } + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + +#if !UF_ENABLE + if (sub_tile_id == 0) { + wait_flag(PIPE_FIX, PIPE_M, accTileEvtID); + } +#endif + +#if UF_ENABLE + const AccMode accMode = + (sub_tile_id == 0) ? + (is_last_subtile || next_will_be_skipped ? AccMode::InitFinalSum : AccMode::InitPartialSum) : + (is_last_subtile || next_will_be_skipped ? AccMode::AccFinalSum : AccMode::AccPartialSum); + pto_macro_matmul(pMatTile, vMatTile, pvAccTile, accMode); +#else + const AccMode accMode = (sub_tile_id == 0) ? 
AccMode::Init : AccMode::Acc; + pto_macro_matmul(pMatTile, vMatTile, pvAccTile, accMode); +#endif + + set_flag(PIPE_MTE1, PIPE_MTE2, svMatTileEventId); + + if (sub_tile_id == static_cast(kTileFactor) - 1 || next_will_be_skipped) { +#if !UF_ENABLE + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); +#endif + + using PVStoreGlobal = + GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + TALLOC(pvPipe, pvSlotGlobal); + PVStoreGlobal pvStoreGlobal(pvSlotGlobal.data()); +#if UF_ENABLE + TSTORE(pvStoreGlobal, pvAccTile); +#else + TSTORE(pvStoreGlobal, pvAccTile); +#endif + TPUSH(pvPipe, pvSlotGlobal); + +#if !UF_ENABLE + set_flag(PIPE_FIX, PIPE_M, accTileEvtID); +#endif + } // end loop + } // end if DAV_CUBE +} + +template +AICORE inline void compute_p(QKPipe &qkPipe, PPipe &pPipe, int tile_id, int row_slice, __gm__ float *exp_max_ififo, + __gm__ float *global_sum_out, __gm__ float *exp_max_out, TileDataF_T &qkVecTile, + TileDataH_T &x_expT, TileDataF_T &input_reduce_tmp, ReduceTileF_T &m1_local_max, + ReduceTileF_T &l1_local_sum, ReduceTileF_T &m2_global_max, ReduceTileF_T &l2_global_sum, + ReduceTileF_T &l1_exp_max_ififo, TileDataF_T triu, QKSlotGlobal &qkSlotGlobal, + PSlotGlobal &pSlotGlobal, uint64_t pTileEventId, int blk_idx) +{ + constexpr uint32_t Cube_S0 = CUBE_S0; + constexpr uint32_t Cube_S1 = CUBE_S1; + constexpr uint32_t Tile_S1 = TILE_S1; + constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1; + constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor; + constexpr int QKP_CV_FIFO = QKPipe::RingFiFo::SLOT_NUM; + static_assert(QKP_CV_FIFO >= 1, "QKP_CV_FIFO must be >= 1"); + static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1"); + static_assert(Cube_S0 % (VEC_CORES * kTileFactor) == 0, "Vec rows must divide evenly across tile slices"); + const bool initFlag = (tile_id == 0); + if constexpr (DAV_VEC) { + const size_t subblock_base_rows = + static_cast(Cube_S0 / VEC_CORES) * 
static_cast(get_subblockid()); + const size_t row_offset = subblock_base_rows + static_cast(row_slice * Vec_S0); + const int s0_index = blk_idx * Cube_S0 + row_offset; + const int s1_index = tile_id * static_cast(Tile_S1); + wait_flag(PIPE_V, PIPE_MTE2, pTileEventId); + + if (row_slice == 0) { + TPOP(qkPipe, qkSlotGlobal); + } + + using QKLoadGlobal = + GlobalTensor, pto::Stride<1, 1, 1, Tile_S1, 1>>; + QKLoadGlobal qkLoadGlobal(qkSlotGlobal.data() + row_offset * Tile_S1); + TLOAD(qkVecTile, qkLoadGlobal); + if (row_slice == static_cast(kTileFactor) - 1) { + TFREE(qkPipe, qkSlotGlobal); + } + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + // Extract per-slice views into the per-core reduce tiles so each slice writes into its row range + using ReduceSliceTile = Tile; + // reduce tiles live per vector core; offset only by row_slice within the core (no subblock stride) + const size_t reduce_slice_rows = static_cast(row_slice * Vec_S0); + const uint64_t reduce_row_byte_offset = reduce_slice_rows * sizeof(float); + + ReduceSliceTile m1_local_max_slice; + ReduceSliceTile l1_local_sum_slice; + ReduceSliceTile m2_global_max_slice; + ReduceSliceTile l2_global_sum_slice; + ReduceSliceTile l1_exp_max_slice; + + TASSIGN(m1_local_max_slice, (uint64_t)m1_local_max.data() + reduce_row_byte_offset); + TASSIGN(l1_local_sum_slice, (uint64_t)l1_local_sum.data() + reduce_row_byte_offset); + TASSIGN(m2_global_max_slice, (uint64_t)m2_global_max.data() + reduce_row_byte_offset); + TASSIGN(l2_global_sum_slice, (uint64_t)l2_global_sum.data() + reduce_row_byte_offset); + TASSIGN(l1_exp_max_slice, (uint64_t)l1_exp_max_ififo.data() + reduce_row_byte_offset); + + // Extract current slice state from full-length reduce tiles + wait_flag(PIPE_MTE3, PIPE_V, pTileEventId); + if (initFlag) { + pto_macro_fa_softmax( + x_expT, qkVecTile, m1_local_max_slice, l1_local_sum_slice, m2_global_max_slice, l2_global_sum_slice, + l1_exp_max_slice, input_reduce_tmp, 
qkVecTile, triu, s0_index, s1_index); + } else { + pto_macro_fa_softmax( + x_expT, qkVecTile, m1_local_max_slice, l1_local_sum_slice, m2_global_max_slice, l2_global_sum_slice, + l1_exp_max_slice, input_reduce_tmp, qkVecTile, triu, s0_index, s1_index); + } + + set_flag(PIPE_V, PIPE_MTE2, pTileEventId); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + + if (row_slice == 0) { + TALLOC(pPipe, pSlotGlobal); + } + using PStoreGlobal = + GlobalTensor, pto::Stride<1, 1, 1, Tile_S1, 1>>; + PStoreGlobal pStoreGlobal(pSlotGlobal.data() + row_offset * Tile_S1); + TSTORE(pStoreGlobal, x_expT); + if (row_slice == static_cast(kTileFactor) - 1) { + TPUSH(pPipe, pSlotGlobal); + } + + set_flag(PIPE_MTE3, PIPE_V, pTileEventId); + if constexpr (INTERMEDIATE_CHECK) { + // On the final row_slice, emit the exp_max for this subblock only (Cube_S0 / VEC_CORES rows) + if (row_slice == static_cast(kTileFactor) - 1) { + constexpr uint32_t SubblockRows = Cube_S0 / VEC_CORES; + using GlobalPMaxFloatSub = + GlobalTensor, pto::Stride<1, 1, 1, Cube_S0, 1>>; + using ExpMaxSub = Tile; + const size_t base_elems_pmax = + static_cast(tile_id % QKP_CV_FIFO) * static_cast(Cube_S0) + subblock_base_rows; + __gm__ float *p_ptr_fp32 = exp_max_ififo + base_elems_pmax; + GlobalPMaxFloatSub pMaxGlobal(p_ptr_fp32); + ExpMaxSub l1_exp_max_rowmajor; + TRESHAPE(l1_exp_max_rowmajor, l1_exp_max_ififo); + TSTORE(pMaxGlobal, l1_exp_max_rowmajor); + } + } + } +} + +template +AICORE inline void compute_gu(PVPipe &pvPipe, int tile_id, int num_tiles, __gm__ float *o_out, + __gm__ float *o_parts_out, TileOutT &runningOTile, TileOutT &pvVecTile, + ReduceTileF_T &l1_exp_max_ififo, ReduceTileF_T &l2_global_sum, uint64_t guEventId) +{ + constexpr uint32_t Cube_S0 = CUBE_S0; + constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES; + + if constexpr (DAV_VEC) { + wait_flag(PIPE_V, PIPE_MTE2, guEventId); + const size_t subblock_base_rows = + static_cast(Cube_S0 / VEC_CORES) * 
static_cast(get_subblockid()); + + using PVVecGlobal = + GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + PVVecGlobal pvGlobal; + TPOP(pvPipe, pvGlobal); + + if (tile_id == 0) { + TLOAD(runningOTile, pvGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + if constexpr (CAUSAL_MASK) { + if (tile_id == num_tiles - 1) + pto_macro_fa_gu_single_and_last_tile(runningOTile, l2_global_sum); + } + } else { + TLOAD(pvVecTile, pvGlobal); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + if (tile_id < num_tiles - 1) { + pto_macro_fa_gu(runningOTile, pvVecTile, l1_exp_max_ififo); + } else { + pto_macro_fa_gu_last(runningOTile, pvVecTile, l1_exp_max_ififo, l2_global_sum); + } + } + TFREE(pvPipe, pvGlobal); + + set_flag(PIPE_V, PIPE_MTE2, guEventId); + + if (tile_id == num_tiles - 1) { + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + using GlobalOutT = + GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + GlobalOutT outGlobal((__gm__ float *)(o_out + subblock_base_rows * HEAD_SIZE)); + TSTORE(outGlobal, runningOTile); + } + } +} + +template +__global__ AICORE void runTFA(__gm__ uint64_t *ffts_addr, __gm__ half *q, __gm__ half *k, __gm__ half *v, + __gm__ half *p_tile_fifo, __gm__ float *exp_max_ififo, __gm__ float *global_sum_out, + __gm__ float *exp_max_out, __gm__ float *o_out, __gm__ float *o_parts_out, + __gm__ float *qk_tile_fifo, __gm__ float *pv_tile_fifo, __gm__ uint8_t *cv_comm_buf, + __gm__ uint8_t *profile_buf) +{ + uint64_t tStart = get_sys_cnt(); + + set_ffts_base_addr((uint64_t)ffts_addr); + if constexpr (DAV_CUBE) { + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + set_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + } + + // Rename dimensions for clarity: S0 (rows total), Cube_S0 (per-block rows), S1 (cols), HEAD_SIZE (inner) + constexpr uint32_t Cube_S0 = CUBE_S0; + constexpr uint32_t block_rows = S0 / CUBE_S0; + constexpr uint32_t Cube_S1 = CUBE_S1; // 
per-tile S1 chunk + constexpr uint32_t Tile_S1 = TILE_S1; // logical tile along S1 + static_assert(Tile_S1 % Cube_S1 == 0, "TILE_S1 must be divisible by CUBE_S1"); + constexpr uint32_t kTileFactor = Tile_S1 / Cube_S1; // sub-tiles per TILE_S1 + constexpr uint32_t Cube_HEAD = HEAD_SIZE; + constexpr uint32_t Vec_S0 = Cube_S0 / VEC_CORES / kTileFactor; + constexpr uint32_t VecGuRows = Cube_S0 / VEC_CORES; + static_assert(Cube_S0 % (VEC_CORES * kTileFactor) == 0, "Vec rows must divide evenly across tile slices"); + + // -------------------------- + // Tuning knobs (pipeline) + // + // qkPreloadNum controls how many (QK -> P) tiles we warm up before entering the steady-state loop. + // - Larger preload improves overlap (Cube/VEC concurrency) for long S1. + // - Larger preload increases FIFO footprint (qkGlobalTensorNBuffers / pvGlobalTensorNBuffers / + // guGlobalTensorNBuffers). + constexpr uint32_t qkPreloadNum = QK_PRELOAD; + + // Buffer counts for optional double-buffering (default 1) + // - srcVecTNBuffers/xexpVecTNBuffers: Vec ping-pong for QK load and x_exp output + // - *MatTNBuffers: L1 ping-pong for Cube stage (K/P/V) + // Keep these small (1-2) unless you have measured stall bubbles that require deeper buffering. 
+ constexpr uint32_t srcVecTNBuffers = 2; + constexpr uint32_t xexpVecTNBuffers = 2; + constexpr uint32_t outOTileNBuffers = 2; + constexpr uint32_t qMatTNBuffers = 1; + constexpr uint32_t kMatTNBuffers = 2; + constexpr uint32_t pMatTNBuffers = 2; + constexpr uint32_t vMatTNBuffers = 2; + constexpr uint32_t qkp_tile_fifo_size = CV_FIFO_SIZE; + constexpr uint32_t pv_tile_fifo_size = CV_FIFO_SIZE; + static_assert(qkPreloadNum >= 1, "qkPreloadNum must be >= 1"); + static_assert(CV_FIFO_CONS_SYNC_PERIOD >= 1, "CV_FIFO_CONS_SYNC_PERIOD must be >= 1"); + static_assert((qkPreloadNum > 1) || (kTileFactor == 1), "qkPreloadNum must be > 1 unless kTileFactor == 1"); + + // Define tile types for first QK matmul + using TileMatQData = + Tile; + using TileMatKData = + Tile; + // Accumulator rows must match Cube_S0 (per-block rows), not logical S0 + using TileQKData = TileAcc; + + TileMatQData qMatTile[qMatTNBuffers]; + TileMatKData kMatTile[kMatTNBuffers]; + TileQKData qkAccTile; + + // Define tile types for second PV matmul + using TileMatPData = + Tile; + using TileMatVData = + Tile; + using TilePVData = TileAcc; + + TileMatPData pMatTile[pMatTNBuffers]; + TileMatVData vMatTile[vMatTNBuffers]; + TilePVData pvAccTile; + + allocate_cube_tile_buffers(qMatTile, kMatTile, pMatTile, vMatTile); + + // Assign accumulator tiles using ping-pong helper. qk starts at 0, pv starts at 1. 
+ assign_running_acc_tile(qkAccTile, 0); + assign_running_acc_tile(pvAccTile, 1); + + // Define tile types for FA softmax P computation + // UB offsets for softmax tiles + // Define per-tile vector tiles sized to Cube_S1 + using TileDataF_T = Tile; + using TileDataH_T = Tile; + constexpr uint32_t SubblockRows = Cube_S0 / VEC_CORES; + // Reduce tiles cover one vector core's rows (Cube_S0 / VEC_CORES); slices are extracted per row_slice + using ReduceTileF_T = Tile; + + TileDataF_T qkVecTile[srcVecTNBuffers]; + ReduceTileF_T m1_local_max; + TileDataF_T input_reduce_tmp; + TileDataF_T triu; + ReduceTileF_T l1_local_sum; + ReduceTileF_T m2_global_max; + ReduceTileF_T l2_global_sum; + ReduceTileF_T l1_exp_max_ififo[qkp_tile_fifo_size]; + TileDataH_T x_expT[xexpVecTNBuffers]; + + using TileOutGuT = Tile; + TileOutGuT pvVecTile[outOTileNBuffers]; + TileOutGuT runningOTile; + allocate_vec_tile_buffers(qkVecTile, m1_local_max, input_reduce_tmp, l1_local_sum, m2_global_max, + l2_global_sum, l1_exp_max_ififo, x_expT, pvVecTile, runningOTile, triu); + + // block offset for logical S0 +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) // A5 defined macro, don't need to reassign + const int block_idx = get_block_idx(); +#endif + const int block_offset_rows = block_idx * static_cast(Cube_S0); + + constexpr bool use_cv_comm = (!INTERMEDIATE_CHECK) && (block_rows >= static_cast(pto::kCvMaxCores)); + int comm_slot = block_idx; + + if constexpr (use_cv_comm) { + comm_slot = pto::TSYNC_CVID(block_idx, cv_comm_buf); + } + __gm__ uint64_t *profile_entry = nullptr; + if (profile_buf != nullptr) { + std::size_t profile_block_base = static_cast(block_idx) * kFaProfileBytesPerBlock; + std::size_t profile_offset = profile_block_base; + if constexpr (DAV_VEC) { + profile_offset += + (static_cast(get_subblockid()) + 1U) * 1024U; // vec subblock 0/1 use 2nd/3rd KB + } + profile_entry = reinterpret_cast<__gm__ uint64_t *>(profile_buf + profile_offset); + profile_entry[0] = tStart; + } 
+ const size_t p_fifo_block_stride = + static_cast(qkp_tile_fifo_size) * static_cast(Cube_S0) * static_cast(Tile_S1); + const size_t p_max_fifo_block_stride = static_cast(qkp_tile_fifo_size) * static_cast(Cube_S0); + const size_t qk_fifo_block_stride = p_fifo_block_stride; + const size_t pv_fifo_block_stride = + static_cast(pv_tile_fifo_size) * static_cast(Cube_S0) * static_cast(HEAD_SIZE); + + __gm__ half *q_block = q + block_offset_rows * HEAD_SIZE; + __gm__ half *p_tile_fifo_block = p_tile_fifo + static_cast(comm_slot) * p_fifo_block_stride; + __gm__ float *exp_max_ififo_block = exp_max_ififo + static_cast(comm_slot) * p_max_fifo_block_stride; + __gm__ float *global_sum_block = global_sum_out + block_offset_rows; + __gm__ float *exp_max_block = exp_max_out + block_offset_rows; + __gm__ float *o_out_block = o_out + static_cast(block_offset_rows) * static_cast(HEAD_SIZE); + __gm__ float *o_parts_block = o_parts_out + static_cast(block_offset_rows) * static_cast(HEAD_SIZE); + __gm__ float *qk_tile_fifo_block = qk_tile_fifo + static_cast(comm_slot) * qk_fifo_block_stride; + __gm__ float *pv_tile_fifo_block = pv_tile_fifo + static_cast(comm_slot) * pv_fifo_block_stride; + + int num_tiles_s1 = S1 / Tile_S1; + if constexpr (CAUSAL_MASK) + num_tiles_s1 = (1 + ((block_idx * CUBE_S0) / Tile_S1)); + if constexpr (DAV_CUBE) { + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + set_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + set_flag(PIPE_FIX, PIPE_M, EVENT_ID1); + } + if constexpr (DAV_VEC) { + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + set_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); + } + + int p_gu_src_pingpong_id = 0; // shared ping-pong for softmax vec tiles, pv output tiles, and GU input tiles + int k_src_pingpong_id = 0; // separate ping-pong for K tiles + int 
pv_src_pingpong_id = 0; // separate ping-pong for P V tiles + + int qkAccTileEvtID = 0; + int pvAccTileEvtID = 0; + + // FIFO definitions + constexpr uint8_t FiFoDepth = CV_FIFO_SIZE; + #if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) + using QKPipe = + TPipe; + #elif defined(__DAV_C310_CUBE__) || defined(__DAV_C310_VEC__) + using QKPipe = + TPipe; + #endif + QKPipe qkPipe(qk_tile_fifo_block, (uint32_t)(uint64_t)qkVecTile[0].data(), 0x0); + + // pFiFo, pProd, pCons + #if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) + using PPipe = TPipe; + #elif defined(__DAV_C310_CUBE__) || defined(__DAV_C310_VEC__) + using PPipe = TPipe; + #endif + PPipe pPipe(p_tile_fifo_block, 0x0, (uint32_t)(uint64_t)pMatTile[0].data()); + + // pvFiFo, pvProd, pvCons + #if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) + using PVPipe = + TPipe; + #elif defined(__DAV_C310_CUBE__) || defined(__DAV_C310_VEC__) + using PVPipe = + TPipe; + #endif + PVPipe pvPipe(pv_tile_fifo_block, (uint32_t)(uint64_t)pvVecTile[0].data(), 0x0); + + using QKSlotGlobal = + GlobalTensor, pto::Stride<1, 1, 1, Tile_S1, 1>>; + using PSlotGlobal = + GlobalTensor, pto::Stride<1, 1, 1, Tile_S1, 1>>; + using PVSlotGlobal = + GlobalTensor, pto::Stride<1, 1, 1, HEAD_SIZE, 1>>; + QKSlotGlobal qkSlotGlobal; + PSlotGlobal pSlotGlobal; + PVSlotGlobal pvSlotGlobal; + + // QK and P pre-computation (tile_id based) + for (int preload_tile = 0; preload_tile < static_cast(qkPreloadNum) && preload_tile < num_tiles_s1; + ++preload_tile) { + if constexpr (DAV_CUBE) { + for (int sub_tile = 0; sub_tile < static_cast(kTileFactor); ++sub_tile) { + qkAccTileEvtID = assign_running_acc_tile(qkAccTile); + compute_qk( + qkPipe, preload_tile, sub_tile, q_block, k, qMatTile[0], + kMatTile[k_src_pingpong_id % kMatTNBuffers], qkAccTile, qkSlotGlobal, + k_src_pingpong_id % kMatTNBuffers, qkAccTileEvtID, block_idx); + k_src_pingpong_id++; + } + } + if constexpr (DAV_VEC) { + for (int row_slice = 0; row_slice < 
static_cast(kTileFactor); ++row_slice) { + // Init only on the very first S1 tile; row_slice partitions rows within that tile + compute_p( + qkPipe, pPipe, preload_tile, row_slice, exp_max_ififo_block, global_sum_block, exp_max_block, + qkVecTile[p_gu_src_pingpong_id % srcVecTNBuffers], x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], + input_reduce_tmp, m1_local_max, l1_local_sum, m2_global_max, l2_global_sum, + l1_exp_max_ififo[preload_tile % qkp_tile_fifo_size], triu, qkSlotGlobal, pSlotGlobal, + p_gu_src_pingpong_id % xexpVecTNBuffers, block_idx); + p_gu_src_pingpong_id++; + } + } + } + + for (int tile_id = 0; tile_id < num_tiles_s1; ++tile_id) { + int next_qk_tile = (tile_id + static_cast(qkPreloadNum) >= num_tiles_s1) ? + -1 : + (tile_id + static_cast(qkPreloadNum)); + + if (next_qk_tile != -1) + qkAccTileEvtID = assign_running_acc_tile(qkAccTile); + pvAccTileEvtID = assign_running_acc_tile(pvAccTile); + + for (int sub_tile = 0; sub_tile < static_cast(kTileFactor); ++sub_tile) { + if constexpr (DAV_CUBE) { + if (next_qk_tile != -1) { + compute_qk( + qkPipe, next_qk_tile, sub_tile, q_block, k, qMatTile[0], + kMatTile[k_src_pingpong_id % kMatTNBuffers], qkAccTile, qkSlotGlobal, + k_src_pingpong_id % kMatTNBuffers, qkAccTileEvtID, block_idx); + k_src_pingpong_id++; + } + } + + if constexpr (DAV_VEC) { + if (next_qk_tile != -1) { + compute_p( + qkPipe, pPipe, next_qk_tile, sub_tile, exp_max_ififo_block, global_sum_block, + exp_max_block, qkVecTile[p_gu_src_pingpong_id % srcVecTNBuffers], + x_expT[p_gu_src_pingpong_id % xexpVecTNBuffers], input_reduce_tmp, m1_local_max, + l1_local_sum, m2_global_max, l2_global_sum, + l1_exp_max_ififo[next_qk_tile % qkp_tile_fifo_size], triu, qkSlotGlobal, pSlotGlobal, + p_gu_src_pingpong_id % xexpVecTNBuffers, block_idx); + p_gu_src_pingpong_id++; + } + } + + if constexpr (DAV_CUBE) { + compute_pv( + pPipe, pvPipe, tile_id, sub_tile, v, pMatTile[pv_src_pingpong_id % pMatTNBuffers], + vMatTile[pv_src_pingpong_id % vMatTNBuffers], 
pvAccTile, + pSlotGlobal, pvSlotGlobal, pv_src_pingpong_id % vMatTNBuffers + PV_EVENT_ID0, pvAccTileEvtID, + block_idx); + pv_src_pingpong_id++; + } + } + + if constexpr (DAV_VEC) { + compute_gu( + pvPipe, tile_id, num_tiles_s1, o_out_block, o_parts_block, runningOTile, + pvVecTile[p_gu_src_pingpong_id % outOTileNBuffers], l1_exp_max_ififo[tile_id % qkp_tile_fifo_size], + l2_global_sum, p_gu_src_pingpong_id % outOTileNBuffers); + p_gu_src_pingpong_id++; + } + } + + if constexpr (DAV_CUBE) { + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_M, PIPE_MTE1, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_MTE1, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID0); + wait_flag(PIPE_FIX, PIPE_M, EVENT_ID1); +#ifdef __DAV_C220_CUBE__ + wait_flag_dev(CV_BLOCK_END); // wait for vector done all reading +#endif + } + + if constexpr (DAV_VEC) { + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_V, EVENT_ID1); +#ifdef __DAV_C220_VEC__ + ffts_cross_core_sync(PIPE_MTE2, _getFFTSMsg(CV_CORE_SYNC, CV_BLOCK_END)); // cube can exit CV comm now +#endif + } + + pipe_barrier(PIPE_ALL); + uint64_t tEnd = get_sys_cnt(); + if (profile_entry != nullptr) { + profile_entry[1] = tEnd; + } +#ifdef _DEBUG + if constexpr (DAV_CUBE) { + cce::printf("Core %d Cube Block %d, Start @%d End @%d (%d us)\n", get_coreid(), block_idx, int(tStart), + int(tEnd), int(tEnd - tStart) * 20 / 1000); + } else { + cce::printf("Core %d Vec Block %d, SubBlock %d, Start @%d End @%d (%d us)\n", get_coreid(), block_idx, + int(get_subblockid()), int(tStart), int(tEnd), int(tEnd - tStart) * 20 / 1000); + } +#endif +} + +// Empty kernel to warm up cores +__global__ AICORE __attribute__((aic)) void warmup_kernel() +{} + +// Host wrapper +template +void LaunchTFA(uint16_t *ffts, 
aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo, + float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, + float *qk_tile_fifo, float *pv_tile_fifo, uint8_t *profile_data, aclrtStream stream, + uint8_t *cv_comm_buf) +{ + static_assert(S0 % CUBE_S0 == 0, "S0 must be divisible by CUBE_S0"); + constexpr uint32_t block_rows = S0 / CUBE_S0; + +#if defined(__DAV_C220_CUBE__) || defined(__DAV_C220_VEC__) + // Warm up all cores first, then prefetch q/k/v into L2 + warmup_kernel<<<24, nullptr, stream>>>(); + + const uint64_t q_bytes = static_cast(S0) * static_cast(HEAD_SIZE) * sizeof(half); + const uint64_t kv_bytes = static_cast(S1) * static_cast(HEAD_SIZE) * sizeof(half); + constexpr bool kPrefetchUseSdma = true; // simulation cannot use sdma + constexpr int kPrefetchAivCores = 40; // only used when kPrefetchUseSdma is false + + if constexpr (kPrefetchUseSdma) { + PTO_PREFETCH((__gm__ void *)q, q_bytes, stream); + PTO_PREFETCH((__gm__ void *)k, kv_bytes, stream); + PTO_PREFETCH((__gm__ void *)v, kv_bytes, stream); + } else { + PTO_PREFETCH((__gm__ void *)q, q_bytes, stream); + PTO_PREFETCH((__gm__ void *)k, kv_bytes, stream); + PTO_PREFETCH((__gm__ void *)v, kv_bytes, stream); + } +#endif + + runTFA<<>>( + (__gm__ uint64_t *)ffts, (half *)q, (half *)k, (half *)v, (half *)p_tile_fifo, exp_max_ififo, global_sum_out, + exp_max_out, o_out, o_parts_out, qk_tile_fifo, pv_tile_fifo, cv_comm_buf, profile_data); +} + +// Backward-compatible overload without profiling buffer +template +void LaunchTFA(uint16_t *ffts, aclFloat16 *q, aclFloat16 *k, aclFloat16 *v, aclFloat16 *p_tile_fifo, + float *exp_max_ififo, float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, + float *qk_tile_fifo, float *pv_tile_fifo, aclrtStream stream, uint8_t *cv_comm_buf) +{ + LaunchTFA(ffts, q, k, v, p_tile_fifo, exp_max_ififo, global_sum_out, exp_max_out, o_out, + o_parts_out, qk_tile_fifo, pv_tile_fifo, nullptr, 
stream, cv_comm_buf); +} + +#include "generated_cases.h" + +#define INSTANTIATE_TFA(S0, HEAD, S1, CUBE_S0, CUBE_S1, TILE_S1, QK_PRELOAD, CAUSAL_MASK) \ + template void LaunchTFA( \ + uint16_t * ffts, aclFloat16 * q, aclFloat16 * k, aclFloat16 * v, aclFloat16 * p_out, float *p_out_fp32, \ + float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, float *qk_out, float *pv_out, \ + uint8_t *profile_data, aclrtStream stream, uint8_t *cv_comm_buf); \ + template void LaunchTFA( \ + uint16_t * ffts, aclFloat16 * q, aclFloat16 * k, aclFloat16 * v, aclFloat16 * p_out, float *p_out_fp32, \ + float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, float *qk_out, float *pv_out, \ + aclrtStream stream, uint8_t *cv_comm_buf); \ + template void LaunchTFA( \ + uint16_t * ffts, aclFloat16 * q, aclFloat16 * k, aclFloat16 * v, aclFloat16 * p_out, float *p_out_fp32, \ + float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, float *qk_out, float *pv_out, \ + uint8_t *profile_data, aclrtStream stream, uint8_t *cv_comm_buf); \ + template void LaunchTFA( \ + uint16_t * ffts, aclFloat16 * q, aclFloat16 * k, aclFloat16 * v, aclFloat16 * p_out, float *p_out_fp32, \ + float *global_sum_out, float *exp_max_out, float *o_out, float *o_parts_out, float *qk_out, float *pv_out, \ + aclrtStream stream, uint8_t *cv_comm_buf); + +TFA_FOR_EACH_CASE(INSTANTIATE_TFA) + +#undef INSTANTIATE_TFA \ No newline at end of file diff --git a/examples/aot/flash_attention/split_pipe/compile.sh b/examples/aot/flash_attention/split_pipe/compile.sh new file mode 100644 index 00000000..e23d4a61 --- /dev/null +++ b/examples/aot/flash_attention/split_pipe/compile.sh @@ -0,0 +1,84 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. +# CANN Open Software License Agreement Version 2.0 +# +# AOT-compile the flash-attention kernel for one or more sequence lengths. 
+# +# Usage: +# bash compile.sh # build the default set: NUM_TILES = 16,32,64,128 +# # -> fa.so, fa_32.so, fa_64.so, fa_128.so +# # (NUM_TILES=16 → 8k seqlen → fa.so) +# FA_TILES=16,64 bash compile.sh # build only the listed NUM_TILES variants +# FA_TILES=16 bash compile.sh # single-variant build (legacy behavior) +# +# Each NUM_TILES value N produces fa${TAG}.{mlir,cpp,so} where +# TAG = "" if N == 16 (the builder's default → plain "fa.so") +# TAG = "_N" otherwise. + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ARTIFACT_DIR="${SCRIPT_DIR}/build_artifacts" +PTO_LIB_PATH="${PTO_LIB_PATH:-/sources/pto-isa}" + +mkdir -p "${ARTIFACT_DIR}" + +build_variant() { + local num_tiles="$1" + local tag + if [[ "${num_tiles}" == "16" ]]; then + tag="" + else + tag="_${num_tiles}" + fi + local mlir_path="${ARTIFACT_DIR}/fa${tag}.mlir" + local generated_cpp="${ARTIFACT_DIR}/fa${tag}.cpp" + local lib_path="${ARTIFACT_DIR}/fa${tag}.so" + + echo "==> Building NUM_TILES=${num_tiles} -> $(basename "${lib_path}")" + rm -f "${mlir_path}" "${generated_cpp}" "${lib_path}" + + FA_NUM_TILES="${num_tiles}" \ + FA_S1_TILE="${FA_S1_TILE:-512}" \ + FA_Q_ROWS="${FA_Q_ROWS:-3072}" \ + python "${SCRIPT_DIR}/kernels/fa_performance_builder.py" > "${mlir_path}" + ptoas --pto-arch=a3 --pto-level=level3 --enable-insert-sync "${mlir_path}" > "${generated_cpp}" + + bisheng \ + -I"${PTO_LIB_PATH}/include" \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + -cce-enable-mix \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + -DKERNEL_CPP="\"${generated_cpp}\"" \ + "${SCRIPT_DIR}/caller.cpp" \ + -o "${lib_path}" + + echo " built 
${lib_path}" + { + echo "FA_NUM_TILES=${num_tiles}" + echo "FA_S1_TILE=${FA_S1_TILE:-512}" + echo "FA_Q_ROWS=${FA_Q_ROWS:-3072}" + } >"${ARTIFACT_DIR}/fa${tag}.build_env" +} + +# Default tile set covers seqlen = NUM_TILES * FA_S1_TILE (FA_S1_TILE defaults to 512). +# Examples at FA_S1_TILE=512: 16 -> 8k, 32 -> 16k, 64 -> 32k, 128 -> 64k. +# FA_Q_ROWS defaults to 3072; must match python run env when invoking kernels. +FA_TILES="${FA_TILES:-16,32,64,128}" + +IFS=',' read -r -a tile_list <<< "${FA_TILES}" +for nt in "${tile_list[@]}"; do + nt_trim="$(echo "${nt}" | tr -d '[:space:]')" + [[ -z "${nt_trim}" ]] && continue + build_variant "${nt_trim}" +done + +echo "Done. Built variants: ${FA_TILES}" diff --git a/examples/aot/flash_attention/split_pipe/kernels/fa_performance_builder.py b/examples/aot/flash_attention/split_pipe/kernels/fa_performance_builder.py new file mode 100644 index 00000000..795c2f41 --- /dev/null +++ b/examples/aot/flash_attention/split_pipe/kernels/fa_performance_builder.py @@ -0,0 +1,704 @@ +# Flash-Attention kernel builder. Ports the reference +# `fa_performance_kernel.cpp` (a2a3) software-pipelined schedule onto the +# pto-dsl multi-pipe primitives (ptoas >= 0.29). +# +# This file mirrors the reference C++ scheduler: +# +# constexpr int qkPreloadNum = 2; // warmup depth +# +# /* Prologue: cube emits QK[0..QK_PRELOAD-1]; vec consumes them and +# pushes P[0..QK_PRELOAD-1]. No PV / gu yet. */ +# +# /* Steady state, tile_id 0..N-1: +# cube: if (t+QK_PRELOAD < N) compute_qk(t+QK_PRELOAD); +# compute_pv(tile_id); +# vec: if (t+QK_PRELOAD < N) compute_p(t+QK_PRELOAD); +# compute_gu(tile_id); +# so vec's softmax for the LOOK-AHEAD tile fills the QK consumption +# slot WHILE the cube is computing the current PV[t]. The cube +# stops being blocked on a freshly-pushed P (softmax of t+2 has +# already pushed P[t+2] into the FIFO by the time cube needs it). */ +# +# /* Epilogue: drain the last QK_PRELOAD tiles' PV / gu. 
*/ +# +# The new dependency that the reference solves with `l1_exp_max_ififo`: +# softmax(t+QK_PRELOAD) overwrites the running scratch tile `exp_max` +# (the rescale factor needed by gu(t)). With QK_PRELOAD=2 we therefore +# need a 2-deep ring of `exp_max` tiles (`exp_max_a`, `exp_max_b`). We +# implement the ring by unrolling the steady-state loop in pairs of 2 +# iterations: even iters use `exp_max_a`, odd iters use `exp_max_b`. +# +# Other state in the softmax (`new_global_max`, `new_global_sum`) does +# NOT need a ring: it is monotonic accumulator state across all tiles +# and is only read at the very end (the divide into o_tile). The fact +# that softmax(t+2) advances it ahead of gu(t) is harmless because gu +# never reads it. +# +# Sub-block (TILE_UP_DOWN) parallelism is preserved on every pipe op. +# +# Hardware-flag accounting (§3.5): 3 unidir pipes × 2 = 6 flags ≪ 16. + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +import math +import os + +const = s.const + +# --------------------------------------------------------------------------- +# Static shapes (must match run.py constants) +# --------------------------------------------------------------------------- +S0 = 32 # Q rows per block (vec uses S0_HALF rows; matches cpp_ref cube blocking story) +S0_HALF = S0 // 2 # rows per AIV sub-block +HEAD = 128 # attention head dimension — same as cpp_ref/split_pipe/run.py test_flash(..., head=128) +# cpp_ref benchmarks commonly use tile_s1=512; override with FA_S1_TILE for bring-up/debug (256 fits tighter UB). +S1_TILE = int(os.environ.get("FA_S1_TILE", "512")) +# NUM_TILES is overridable via the FA_NUM_TILES env var so the same builder +# can produce kernels for different sequence lengths +# (S1_TOTAL = S1_TILE * NUM_TILES). +# Constraint: (NUM_TILES - QK_PRELOAD) must be even (steady-state pair unroll). 
NUM_TILES = int(os.environ.get("FA_NUM_TILES", "16"))

S1_TOTAL = S1_TILE * NUM_TILES

# Total Q rows: default 3072 = 128 * 24 matches cpp_ref benchmark `s0 = 128 * 24`.
Q_ROWS = int(os.environ.get("FA_Q_ROWS", "3072"))
# NUM_Q_BLOCKS is a floor division: a Q_ROWS that is not a positive multiple
# of S0 would leave the trailing rows of the output uncomputed, so reject it
# here instead of emitting a kernel that silently skips them.
if Q_ROWS <= 0 or Q_ROWS % S0 != 0:
    raise ValueError(
        f"FA_Q_ROWS={Q_ROWS} must be a positive multiple of S0={S0}"
    )
NUM_Q_BLOCKS = Q_ROWS // S0

# QK preload depth — must be >= 1; reference uses 2. The vec pre-softmaxes
# tiles 0..QK_PRELOAD-1, then the steady-state loop interleaves softmax(t+QK_PRELOAD)
# with gu(t), and the epilogue drains the last QK_PRELOAD gu's.
# (NUM_TILES - QK_PRELOAD) must be even — steady state is pair-unrolled to
# ping-pong the exp_max ring (see below).
QK_PRELOAD = 2
# Explicit raises (not `assert`) so the checks still run under `python -O`.
if (NUM_TILES - QK_PRELOAD) % 2 != 0:
    raise ValueError(
        "Steady-state pair unrolling requires (NUM_TILES - QK_PRELOAD) % 2 == 0"
    )
# The vector kernel Python-unrolls the first steady-state pair (it always
# emits softmax for tiles QK_PRELOAD and QK_PRELOAD+1), while the cube kernel
# only produces those QK tiles inside its runtime loop of STEADY_PAIRS
# iterations. With fewer than QK_PRELOAD + 2 tiles the two kernels disagree on
# the number of pipe pushes/pops, so at least one full pair is mandatory.
if NUM_TILES < QK_PRELOAD + 2:
    raise ValueError(
        f"FA_NUM_TILES={NUM_TILES} is too small: the pipelined schedule needs "
        f"at least QK_PRELOAD + 2 = {QK_PRELOAD + 2} tiles"
    )
STEADY_PAIRS = (NUM_TILES - QK_PRELOAD) // 2

# Per-pipe slot sizes (bytes).
SLOT_SIZE_QK = S0 * S1_TILE * 4  # fp32 QK accumulator
SLOT_SIZE_PV = S0 * HEAD * 4  # fp32 PV accumulator
SLOT_SIZE_P = S0 * S1_TILE * 2  # fp16 P matrix sent vec → cube

# `dir_mask = 1/2` always lowers to slot_num = 8 on a3 (design doc §4.4).
SLOT_NUM = 8
# Kept at 1: bumping to 2 overflows VEC UB at S1_TILE=512.
QK_LOCAL_SLOT_NUM = 1
# PV uses lower-level l2g2l_pipe with local_slot_num=1; the legacy
# aic/aiv_initialize_pipe path forces local = SLOT_NUM = 8 (32 KB MAT)
# whereas local=1 here is just 4 KB.
PV_LOCAL_SLOT_NUM = 1

# GM-staged FIFO bytes / fp32 elements per AIC block.
# Layout inside one block's GM region (offsets in fp32 elements):
# [QK slots | PV slots | P slots], each SLOT_NUM deep.
GM_BYTES_PER_BLOCK = (SLOT_SIZE_QK + SLOT_SIZE_PV + SLOT_SIZE_P) * SLOT_NUM
GM_ELEMS_PER_BLOCK = GM_BYTES_PER_BLOCK // 4
GM_QK_OFF_F32 = 0
GM_PV_OFF_F32 = (SLOT_SIZE_QK * SLOT_NUM) // 4
GM_P_OFF_F32 = GM_PV_OFF_F32 + (SLOT_SIZE_PV * SLOT_NUM) // 4

# On-chip FIFO footprints: QK/PV keep only their *_LOCAL_SLOT_NUM slots
# locally, the P pipe keeps all SLOT_NUM slots.
FIFO_BYTES_QK = SLOT_SIZE_QK * QK_LOCAL_SLOT_NUM
FIFO_BYTES_PV = SLOT_SIZE_PV * PV_LOCAL_SLOT_NUM
FIFO_BYTES_P = SLOT_SIZE_P * SLOT_NUM

# --pto-level=level3 requires explicit byte addresses for every reserve/alloc.
# Vector UB is 192 KiB; recv buffers for `tpop` must fit entirely below that limit.
+_TILE_FP32_BYTES = S0_HALF * S1_TILE * 4 +_TILE_FP16_BYTES = S0_HALF * S1_TILE * 2 +_O_TILE_BYTES = S0_HALF * HEAD * 4 +_VEC_RED_STRIDE = 64 + +# Cube MAT / ACC / LEFT — fits in L1 with K/V sharing RIGHT at addr 0 (single tile footprint). +MAT_Q_OFF = 0 +MAT_K_OFF = MAT_Q_OFF + S0 * HEAD * 2 +MAT_P_RECV_OFF = MAT_K_OFF + HEAD * S1_TILE * 2 +MAT_V_OFF = MAT_P_RECV_OFF + S0 * S1_TILE * 2 +_MAT_CUBE_TAIL = MAT_V_OFF + S1_TILE * HEAD * 2 +_MAT_ALIGN = 65536 +MAT_P_FIFO_OFF = (_MAT_CUBE_TAIL + _MAT_ALIGN - 1) // _MAT_ALIGN * _MAT_ALIGN + +LEFT_Q_OFF = 0 +LEFT_P_OFF = LEFT_Q_OFF + S0 * HEAD * 2 + +RIGHT_KV_OFF = 0 + +ACC_QK_OFF = 0 +ACC_PV_OFF = ACC_QK_OFF + S0 * S1_TILE * 4 + +# Vec scratch (FIFOs then softmax stack). PV `tpop` reuses tmp_tile storage — pipeline serializes gu vs softmax per head32 schedule. +VEC_QK_FIFO_OFF = 0 +VEC_PV_FIFO_OFF = VEC_QK_FIFO_OFF + FIFO_BYTES_QK +_VEC_POST_FIFO_OFF = VEC_PV_FIFO_OFF + FIFO_BYTES_PV +VEC_QK_RECV_OFF = _VEC_POST_FIFO_OFF +VEC_TMP_OFF = VEC_QK_RECV_OFF + _TILE_FP32_BYTES +# GU consumes PV before softmax touches tmp again (pairs head32 schedule). +VEC_PV_RECV_OFF = VEC_TMP_OFF +VEC_P_FP32_OFF = VEC_TMP_OFF + _TILE_FP32_BYTES +# fp16 P lives in the **upper half** of the fp32 softmax tile — avoid bisheng assigning +# fp32/fp16 Tiles to the identical UB offset (duplicate TASSIGN breaks TCVT on device). 
def meta_data():
    # Builds every element / pointer / tensor / tile type used by the kernels
    # below and returns them all via `locals()`, i.e. as a dict keyed by the
    # local variable names. It is handed to `to_ir_module(meta_data=meta_data)`.
    # NOTE(review): the kernels refer to these types both as bare names
    # (e.g. `q_mat_ty`) and as string annotations (e.g. "ptr_fp32"), so the
    # local names here are effectively the interface — presumably resolved by
    # name inside ptodsl; do not rename them without confirming.
    fp16 = pto.float16
    fp32 = pto.float32
    ffts_ty = pto.ffts_type
    ptr_fp16 = pto.PtrType(fp16)
    ptr_fp32 = pto.PtrType(fp32)
    i32 = pto.int32

    # Rank-2 global tensors: fp16 for Q/K/V inputs, fp32 for the O output.
    qkv_tensor_ty = pto.TensorType(rank=2, dtype=fp16)
    o_tensor_ty = pto.TensorType(rank=2, dtype=fp32)

    # Sub-tensor (slice) types; shapes are fixed by the module constants.
    q_sub_ty = pto.SubTensorType(shape=[S0, HEAD], dtype=fp16)
    kt_sub_ty = pto.SubTensorType(shape=[HEAD, S1_TILE], dtype=fp16)
    v_sub_ty = pto.SubTensorType(shape=[S1_TILE, HEAD], dtype=fp16)
    o_sub_ty = pto.SubTensorType(shape=[S0, HEAD], dtype=fp32)
    # Half-height output slice: each vector sub-block stores S0_HALF rows.
    o_sub_half_ty = pto.SubTensorType(shape=[S0_HALF, HEAD], dtype=fp32)

    # --- Cube tile types ---
    q_mat_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp16, memory_space="MAT")
    q_left_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp16, memory_space="LEFT")
    # K^T staging tile is the only MAT tile with an explicit layout config.
    k_mat_ty = pto.TileBufType(
        shape=[HEAD, S1_TILE],
        dtype=fp16,
        memory_space="MAT",
        config=pto.TileBufConfig(blayout="RowMajor", slayout="ColMajor"),
    )
    k_right_ty = pto.TileBufType(
        shape=[HEAD, S1_TILE], dtype=fp16, memory_space="RIGHT"
    )
    qk_acc_ty = pto.TileBufType(shape=[S0, S1_TILE], dtype=fp32, memory_space="ACC")
    # Type of the fp16 P tile popped from the vec → cube pipe.
    p_recv_ty = pto.TileBufType(
        shape=[S0, S1_TILE],
        dtype=fp16,
        memory_space="MAT",
    )
    p_left_ty = pto.TileBufType(shape=[S0, S1_TILE], dtype=fp16, memory_space="LEFT")
    v_mat_ty = pto.TileBufType(shape=[S1_TILE, HEAD], dtype=fp16, memory_space="MAT")
    v_right_ty = pto.TileBufType(
        shape=[S1_TILE, HEAD], dtype=fp16, memory_space="RIGHT"
    )
    pv_acc_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp32, memory_space="ACC")

    # --- Vector tile types (HALF-size — split=1 on every pipe op) ---
    qk_vec_ty = pto.TileBufType(
        shape=[S0_HALF, S1_TILE], dtype=fp32, memory_space="VEC"
    )
    p_fp32_ty = pto.TileBufType(
        shape=[S0_HALF, S1_TILE], dtype=fp32, memory_space="VEC"
    )
    p_fp16_ty = pto.TileBufType(
        shape=[S0_HALF, S1_TILE], dtype=fp16, memory_space="VEC"
    )
    pv_vec_ty = pto.TileBufType(shape=[S0_HALF, HEAD], dtype=fp32, memory_space="VEC")
    # Per-row reduction result as a column ([S0_HALF, 1]); used for the
    # running max / sum / rescale tiles in the vector kernel.
    red_ty = pto.TileBufType(
        shape=[S0_HALF, 1],
        dtype=fp32,
        memory_space="VEC",
        config=pto.TileBufConfig(blayout="ColMajor", slayout="NoneBox"),
    )
    # Same S0_HALF values viewed as a single row ([1, S0_HALF]); the vector
    # kernel reshapes `red_ty` tiles to this type for element-wise updates.
    red_row_ty = pto.TileBufType(
        shape=[1, S0_HALF],
        dtype=fp32,
        memory_space="VEC",
    )
    o_vec_ty = pto.TileBufType(shape=[S0_HALF, HEAD], dtype=fp32, memory_space="VEC")

    # Export every name defined above (including ones not referenced by name
    # in this file's kernels, such as `i32` and `o_sub_ty`).
    return locals()
+ # =================================================================== + @pto.func(kernel="cube") + def cube_kernel( + gm_slot_buffer: "ptr_fp32", + gm_q: "ptr_fp16", + gm_k: "ptr_fp16", + gm_v: "ptr_fp16", + ) -> None: + c0 = const(0) + c1 = const(1) + c2 = const(2) + cS0 = const(S0) + cHEAD = const(HEAD) + cS1_TILE = const(S1_TILE) + cS1_TOTAL = const(S1_TOTAL) + cNUM_TILES = const(NUM_TILES) + cNUM_Q_BLOCKS = const(NUM_Q_BLOCKS) + + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + floor_div = cNUM_Q_BLOCKS // num_blocks + extra = cNUM_Q_BLOCKS % num_blocks + fat_start = bid * (floor_div + c1) + thin_start = extra * (floor_div + c1) + (bid - extra) * floor_div + qb_start = s.select(bid < extra, fat_start, thin_start) + q_blocks_this_core = s.select(bid < extra, floor_div + c1, floor_div) + qb_end = qb_start + q_blocks_this_core + + gm_blk_offset = bid * const(GM_ELEMS_PER_BLOCK) + gm_blk = pto.add_ptr(gm_slot_buffer, gm_blk_offset) + gm_qk = pto.add_ptr(gm_blk, const(GM_QK_OFF_F32)) + gm_pv = pto.add_ptr(gm_blk, const(GM_PV_OFF_F32)) + gm_p = pto.add_ptr(gm_blk, const(GM_P_OFF_F32)) + + # ---- Pipe QK_C2V (lower-level init) ---- + qk_c2v_import = pto.import_reserved_buffer( + name="fa_qk_c2v_fifo", peer_func="@vector_kernel" + ) + qk_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_QK, + slot_num=SLOT_NUM, + local_slot_num=QK_LOCAL_SLOT_NUM, + gm_addr=gm_qk, + local_addr=qk_c2v_import, + ) + + # ---- Pipe PV_C2V (lower-level init: PV_LOCAL_SLOT_NUM VEC slots) ---- + pv_c2v_import = pto.import_reserved_buffer( + name="fa_pv_c2v_fifo", peer_func="@vector_kernel" + ) + pv_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_PV, + slot_num=SLOT_NUM, + local_slot_num=PV_LOCAL_SLOT_NUM, + gm_addr=gm_pv, + local_addr=pv_c2v_import, + ) + + # ---- Pipe P_V2C (id = ID_P) ---- + p_v2c_local = pto.reserve_buffer( + name="fa_p_v2c_fifo", + size=FIFO_BYTES_P, + location="MAT", + auto_alloc=False, 
+ base=MAT_P_FIFO_OFF, + ) + pto.aic_initialize_pipe( + id=ID_P, + dir_mask=2, + slot_size=SLOT_SIZE_P, + gm_slot_buffer=gm_p, + c2v_consumer_buf=const(0, s.int32), + v2c_consumer_buf=p_v2c_local, + nosplit=False, + ) + + right_base = const(RIGHT_KV_OFF, s.int64) + q_mat = pto.alloc_tile(q_mat_ty, addr=const(MAT_Q_OFF, s.int64)) + q_left = pto.alloc_tile(q_left_ty, addr=const(LEFT_Q_OFF, s.int64)) + k_mat_s = pto.alloc_tile(k_mat_ty, addr=const(MAT_K_OFF, s.int64)) + k_right_s = pto.alloc_tile(k_right_ty, addr=right_base) + qk_acc_s = pto.alloc_tile(qk_acc_ty, addr=const(ACC_QK_OFF, s.int64)) + p_recv_s = pto.alloc_tile(p_recv_ty, addr=const(MAT_P_RECV_OFF, s.int64)) + p_left_s = pto.alloc_tile(p_left_ty, addr=const(LEFT_P_OFF, s.int64)) + v_mat_s = pto.alloc_tile(v_mat_ty, addr=const(MAT_V_OFF, s.int64)) + v_right_s = pto.alloc_tile(v_right_ty, addr=right_base) + pv_acc_s = pto.alloc_tile(pv_acc_ty, addr=const(ACC_PV_OFF, s.int64)) + # Aliasing wrappers: keep the per-iteration `[buf]` indexing pattern + # in the body even though all slots currently point at one alloc. 
+ k_mat = [k_mat_s, k_mat_s] + k_right = [k_right_s, k_right_s] + qk_acc = [qk_acc_s, qk_acc_s] + p_recv = [p_recv_s, p_recv_s] + p_left = [p_left_s, p_left_s] + v_mat = [v_mat_s, v_mat_s] + v_right = [v_right_s, v_right_s] + pv_acc = [pv_acc_s, pv_acc_s] + + cQ_ROWS = const(Q_ROWS) + tv_q = pto.as_tensor( + qkv_tensor_ty, ptr=gm_q, shape=[cQ_ROWS, cHEAD], strides=[cHEAD, c1] + ) + tv_k = pto.as_tensor( + qkv_tensor_ty, + ptr=gm_k, + shape=[cHEAD, cS1_TOTAL], + strides=[c1, cHEAD], + ) + tv_v = pto.as_tensor( + qkv_tensor_ty, + ptr=gm_v, + shape=[cS1_TOTAL, cHEAD], + strides=[cHEAD, c1], + ) + + for qb in pto.range(qb_start, qb_end, c1): + q_row_off = qb * cS0 + + q_view = pto.slice_view( + q_sub_ty, source=tv_q, offsets=[q_row_off, c0], sizes=[cS0, cHEAD] + ) + pto.load(q_view, q_mat) + tile.mov(q_mat, q_left) + + # =================== Cube prologue: emit QK[0..QK_PRELOAD-1] =================== + # Each prologue QK uses its own k_mat / k_right / qk_acc slot + # so MTE2 load of K[1] overlaps the M of QK[0]. + for k in range(QK_PRELOAD): + k_off = const(k * S1_TILE) + kt_view_k = pto.slice_view( + kt_sub_ty, + source=tv_k, + offsets=[c0, k_off], + sizes=[cHEAD, cS1_TILE], + ) + pto.load(kt_view_k, k_mat[k]) + tile.mov(k_mat[k], k_right[k]) + tile.matmul(q_left, k_right[k], qk_acc[k]) + pto.tpush(qk_acc[k], qk_pipe, SPLIT_UP_DOWN) + + # Preload V[0] for the very first PV. + v_view_0 = pto.slice_view( + v_sub_ty, + source=tv_v, + offsets=[c0, c0], + sizes=[cS1_TILE, cHEAD], + ) + pto.load(v_view_0, v_mat[0]) + + # =================== Cube steady state =================== + # Pair-unrolled. 
def emit_cube_step(t_idx, b):
    # Emits one steady-state cube iteration for tile `t_idx`:
    #   1. start loading K for the look-ahead tile (t_idx + QK_PRELOAD),
    #   2. pop P[t_idx] from the vec → cube pipe and stage it with V[t_idx],
    #   3. start loading V[t_idx + 1] for the next step,
    #   4. compute PV[t_idx] and push it to the vector kernel,
    #   5. compute QK for the look-ahead tile and push it to the vector kernel.
    # `t_idx` is a DSL index value; `b` is a Python int (0 or 1) chosen by the
    # caller from the tile's parity. It only selects an entry of the two-element
    # buffer lists, which currently alias a single allocation each.
    # next_qk = t_idx + QK_PRELOAD (only used when in main range)
    next_qk = t_idx + const(QK_PRELOAD)
    kt_off = next_qk * cS1_TILE
    # Column slice of K^T: all HEAD rows, S1_TILE columns starting at kt_off.
    kt_view = pto.slice_view(
        kt_sub_ty,
        source=tv_k,
        offsets=[c0, kt_off],
        sizes=[cHEAD, cS1_TILE],
    )
    pto.load(kt_view, k_mat[b])

    # Receive P[t_idx]; the pipe slot is released right after it has been
    # moved into the LEFT operand buffer.
    p_raw = pto.tpop_from_aiv(p_recv_ty, SPLIT_UP_DOWN, id=ID_P)
    tile.mov(p_raw, p_left[b])
    pto.tfree_from_aiv(SPLIT_UP_DOWN, id=ID_P)
    # V[t_idx] was loaded into v_mat[b] by the previous step (or the prologue).
    tile.mov(v_mat[b], v_right[b])

    # Preload V for the next tile into the opposite slot.
    v_off = (t_idx + c1) * cS1_TILE
    v_view = pto.slice_view(
        v_sub_ty,
        source=tv_v,
        offsets=[v_off, c0],
        sizes=[cS1_TILE, cHEAD],
    )
    pto.load(v_view, v_mat[1 - b])

    # PV[t_idx] = P[t_idx] @ V[t_idx]  → vector kernel (consumed by gu).
    tile.matmul(p_left[b], v_right[b], pv_acc[b])
    pto.tpush(pv_acc[b], pv_pipe, SPLIT_UP_DOWN)

    # QK[next_qk] = Q @ K^T[next_qk] → vector kernel (consumed by softmax).
    # K is moved to RIGHT only now: k_right and v_right are allocated at the
    # same RIGHT address, so this must come after the PV matmul above.
    tile.mov(k_mat[b], k_right[b])
    tile.matmul(q_left, k_right[b], qk_acc[b])
    pto.tpush(qk_acc[b], qk_pipe, SPLIT_UP_DOWN)
+ for k in range(QK_PRELOAD): + b = k % 2 + p_raw = pto.tpop_from_aiv(p_recv_ty, SPLIT_UP_DOWN, id=ID_P) + tile.mov(p_raw, p_left[b]) + pto.tfree_from_aiv(SPLIT_UP_DOWN, id=ID_P) + tile.mov(v_mat[b], v_right[b]) + # Preload V[t+1] into the OPPOSITE slot, only if not the + # very last tile. + if k < QK_PRELOAD - 1: + next_v_idx = NUM_TILES - QK_PRELOAD + k + 1 + v_off_k = const(next_v_idx * S1_TILE) + v_view_k = pto.slice_view( + v_sub_ty, + source=tv_v, + offsets=[v_off_k, c0], + sizes=[cS1_TILE, cHEAD], + ) + pto.load(v_view_k, v_mat[1 - b]) + tile.matmul(p_left[b], v_right[b], pv_acc[b]) + pto.tpush(pv_acc[b], pv_pipe, SPLIT_UP_DOWN) + + # =================================================================== + # Vector kernel — PRELOAD=2 software-pipelined. + # =================================================================== + @pto.func(kernel="vector") + def vector_kernel( + gm_slot_buffer: "ptr_fp32", + gm_o: "ptr_fp32", + ) -> None: + c0 = const(0) + c1 = const(1) + c2 = const(2) + cS0 = const(S0) + cS0_HALF = const(S0_HALF) + cHEAD = const(HEAD) + cNUM_TILES = const(NUM_TILES) + cNUM_Q_BLOCKS = const(NUM_Q_BLOCKS) + + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + floor_div = cNUM_Q_BLOCKS // num_blocks + extra = cNUM_Q_BLOCKS % num_blocks + fat_start = bid * (floor_div + c1) + thin_start = extra * (floor_div + c1) + (bid - extra) * floor_div + qb_start = s.select(bid < extra, fat_start, thin_start) + q_blocks_this_core = s.select(bid < extra, floor_div + c1, floor_div) + qb_end = qb_start + q_blocks_this_core + + gm_blk_offset = bid * const(GM_ELEMS_PER_BLOCK) + gm_blk = pto.add_ptr(gm_slot_buffer, gm_blk_offset) + gm_qk = pto.add_ptr(gm_blk, const(GM_QK_OFF_F32)) + gm_pv = pto.add_ptr(gm_blk, const(GM_PV_OFF_F32)) + gm_p = pto.add_ptr(gm_blk, const(GM_P_OFF_F32)) + + # ---- Pipe QK_C2V ---- + qk_c2v_local = pto.reserve_buffer( + name="fa_qk_c2v_fifo", + size=FIFO_BYTES_QK, + location="VEC", + auto_alloc=False, + 
base=VEC_QK_FIFO_OFF, + ) + qk_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_QK, + slot_num=SLOT_NUM, + local_slot_num=QK_LOCAL_SLOT_NUM, + gm_addr=gm_qk, + local_addr=qk_c2v_local, + ) + + # ---- Pipe PV_C2V (lower-level init: PV_LOCAL_SLOT_NUM VEC slots) ---- + pv_c2v_local = pto.reserve_buffer( + name="fa_pv_c2v_fifo", + size=FIFO_BYTES_PV, + location="VEC", + auto_alloc=False, + base=VEC_PV_FIFO_OFF, + ) + pv_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_PV, + slot_num=SLOT_NUM, + local_slot_num=PV_LOCAL_SLOT_NUM, + gm_addr=gm_pv, + local_addr=pv_c2v_local, + ) + + # ---- Pipe P_V2C ---- + p_v2c_import = pto.import_reserved_buffer( + name="fa_p_v2c_fifo", peer_func="@cube_kernel" + ) + pto.aiv_initialize_pipe( + id=ID_P, + dir_mask=2, + slot_size=SLOT_SIZE_P, + gm_slot_buffer=gm_p, + c2v_consumer_buf=const(0, s.int32), + v2c_consumer_buf=p_v2c_import, + nosplit=False, + ) + + sb_idx = s.index_cast(pto.get_subblock_idx()) + row_off_sb = sb_idx * cS0_HALF + + tmp_tile = pto.alloc_tile(qk_vec_ty, addr=const(VEC_TMP_OFF, s.int64)) + p_fp32 = pto.alloc_tile(p_fp32_ty, addr=const(VEC_P_FP32_OFF, s.int64)) + p_fp16 = pto.alloc_tile(p_fp16_ty, addr=const(VEC_P_FP16_OFF, s.int64)) + o_tile = pto.alloc_tile(o_vec_ty, addr=const(VEC_O_OFF, s.int64)) + new_global_max = pto.alloc_tile( + red_ty, addr=const(VEC_NEW_GLOBAL_MAX_OFF, s.int64) + ) + local_max = pto.alloc_tile(red_ty, addr=const(VEC_LOCAL_MAX_OFF, s.int64)) + new_global_sum = pto.alloc_tile( + red_ty, addr=const(VEC_NEW_GLOBAL_SUM_OFF, s.int64) + ) + local_sum = pto.alloc_tile(red_ty, addr=const(VEC_LOCAL_SUM_OFF, s.int64)) + # Ring of QK_PRELOAD exp_max tiles. With QK_PRELOAD=2 we use a/b + # ping-pong: even-parity tiles use exp_max_a, odd-parity tiles use + # exp_max_b. softmax(t) writes the exp_max for tile t into the + # corresponding slot; gu(t) reads it from the same slot. 
def emit_softmax_step(exp_max_slot, is_init):
    # Emits one streaming-softmax step on the vector side: pop a QK tile,
    # update the running row max / row sum, produce the fp16 P tile and push
    # it to the cube kernel.
    #   exp_max_slot: `red_ty` tile that receives this tile's rescale factor
    #                 (only written on the non-init path).
    #   is_init:      Python bool, resolved at trace time; True seeds the
    #                 running max/sum instead of merging into them.
    # NOTE(review): operand comments below assume the tile ops take
    # (src..., dst) with the destination last — confirm against ptodsl.
    qk_recv = pto.tpop(
        qk_vec_ty,
        qk_pipe,
        SPLIT_UP_DOWN,
        addr=const(VEC_QK_RECV_OFF, s.int64),
    )
    # Apply the 1/sqrt(HEAD) scale in place, then take the per-row max.
    tile.muls(qk_recv, scale, qk_recv)
    tile.row_max(qk_recv, tmp_tile, local_max)

    # Row-shaped ([1, S0_HALF]) aliases of the column reduction tiles, used
    # for the element-wise max / sub / mul / add updates below.
    local_max_r = tile.reshape(red_row_ty, local_max)
    new_global_max_r = tile.reshape(red_row_ty, new_global_max)
    exp_max_r = tile.reshape(red_row_ty, exp_max_slot)
    new_global_sum_r = tile.reshape(red_row_ty, new_global_sum)
    local_sum_r = tile.reshape(red_row_ty, local_sum)

    if is_init:
        # First tile: P = exp(QK - local_max); the running max and sum are
        # simply initialised from this tile. exp_max_slot is left untouched.
        tile.row_expand_sub(qk_recv, local_max, p_fp32)
        # Multiply-by-one is used as a copy: new_global_max := local_max.
        tile.muls(local_max_r, f32_one, new_global_max_r)
        tile.exp(p_fp32, p_fp32)
        tile.row_sum(p_fp32, tmp_tile, new_global_sum)
    else:
        # Merge with the running state:
        #   local_max      := max(local_max, new_global_max)
        #   exp_max        := exp(old_global_max - local_max)   (rescale factor)
        #   new_global_max := local_max
        #   P              := exp(QK - local_max)
        #   new_global_sum := new_global_sum * exp_max + row_sum(P)
        tile.max(local_max_r, new_global_max_r, local_max_r)
        tile.sub(new_global_max_r, local_max_r, exp_max_r)
        tile.muls(local_max_r, f32_one, new_global_max_r)
        tile.row_expand_sub(qk_recv, local_max, p_fp32)
        tile.exp(exp_max_r, exp_max_r)
        tile.exp(p_fp32, p_fp32)
        tile.mul(new_global_sum_r, exp_max_r, new_global_sum_r)
        tile.row_sum(p_fp32, tmp_tile, local_sum)
        tile.add(new_global_sum_r, local_sum_r, new_global_sum_r)

    # Down-convert P to fp16, hand it to the cube kernel, and only then
    # release the QK pipe slot that backed `qk_recv`.
    tile.cvt(p_fp32, p_fp16)
    pto.tpush_to_aic(p_fp16, SPLIT_UP_DOWN, id=ID_P)
    pto.tfree(qk_pipe, SPLIT_UP_DOWN)
@pto.func
def call_both(
    ffts_addr: "ffts_ty",
    gm_slot_buffer: "ptr_fp32",
    gm_q: "ptr_fp16",
    gm_k: "ptr_fp16",
    gm_v: "ptr_fp16",
    gm_o: "ptr_fp32",
) -> None:
    # Entry point: registers the FFTS address, then invokes the cube and the
    # vector kernel. Both receive the same `gm_slot_buffer`, from which each
    # derives its per-block QK / PV / P FIFO regions. The cube kernel reads
    # Q/K/V; the vector kernel writes O.
    # NOTE(review): the two calls are issued back to back here; how they are
    # scheduled relative to each other on device is decided by ptodsl/ptoas
    # and is not visible in this file.
    pto.set_ffts(ffts_addr)
    pto.call(cube_kernel, gm_slot_buffer, gm_q, gm_k, gm_v)
    pto.call(vector_kernel, gm_slot_buffer, gm_o)
Each compile emits matching fa*.build_env next to +# the .so; run.py applies the sidecar for the first FA_BENCH_LENGTHS entry and +# reloads fa_performance_builder (see README). Override FA_* in the environment +# if you skip compile.sh sidecars or resolve ambiguous seq_len factorizations. +# +# * Correctness check: the first requested benchmark length. +# * Benchmark: 8k / 16k / 32k / 64k variants by default. Override with +# FA_BENCH_LENGTHS=8192,32768 (each length must be a multiple of S1_TILE). + +import ctypes +import importlib +import os +import sys +import math + +import torch +import torch_npu # noqa: F401 + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(THIS_DIR, "kernels")) +import fa_performance_builder as fb # noqa: E402 — reloaded in main() after build_env sync + +from ptodsl import do_bench # noqa: E402 +from ptodsl.utils.npu_info import get_num_cube_cores, get_test_device # noqa: E402 + +ARTIFACT_DIR = os.path.join(THIS_DIR, "build_artifacts") +DEFAULT_PLOT_PATH = os.path.join(ARTIFACT_DIR, "fa_benchmark.png") + +# Sequence lengths benchmarked. Override with +# FA_BENCH_LENGTHS=8192,32768 +# Each must be a multiple of S1_TILE (from reload/synced fa_performance_builder) +# and have a matching prebuilt .so. 
+DEFAULT_BENCH_LENGTHS = (8192, 16384, 32768, 65536) + + +def _parse_bench_lengths(): + raw = os.environ.get("FA_BENCH_LENGTHS") + if not raw: + return DEFAULT_BENCH_LENGTHS + return tuple(int(x) for x in raw.split(",") if x.strip()) + + +ATOL = 1e-3 +RTOL = 1e-3 + + +def attn_flops_matmul_softmax_scale( + batch_size: int, + s_q: int, + s_k: int, + h: int, + include_scale: bool = True, + count_exp_as_flop: bool = True, + count_max_as_flop: bool = True, +): + """Same FLOP model as `cpp_ref/split_pipe/run.py` for comparable TFLOP/s.""" + flops_matmul = 4 * batch_size * s_q * s_k * h + flops_scale = (batch_size * s_q * s_k) if include_scale else 0 + + rows = batch_size * s_q + softmax_ops = 0 + if count_max_as_flop: + softmax_ops += rows * (s_k - 1) + softmax_ops += rows * s_k + if count_exp_as_flop: + softmax_ops += rows * s_k + softmax_ops += rows * (s_k - 1) + softmax_ops += rows * s_k + + return flops_matmul + flops_scale + softmax_ops + + +def get_block_dim() -> int: + return min(fb.NUM_Q_BLOCKS, get_num_cube_cores()) + + +def get_slot_elems(block_dim: int) -> int: + return fb.GM_ELEMS_PER_BLOCK * block_dim + + +def num_tiles_for(seq_len: int) -> int: + s1_tile = fb.S1_TILE + if seq_len % s1_tile != 0: + raise ValueError(f"seq_len {seq_len} not divisible by S1_TILE={s1_tile}") + return seq_len // s1_tile + + +def _apply_build_env_matching_seq(seq_len: int) -> None: + """Set FA_* env vars from the fa*.build_env written by compile.sh for this seq_len.""" + if not os.path.isdir(ARTIFACT_DIR): + return + matches = [] + for fn in sorted(os.listdir(ARTIFACT_DIR)): + if not fn.endswith(".build_env"): + continue + path = os.path.join(ARTIFACT_DIR, fn) + kv = {} + try: + with open(path, encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line or line.startswith("#"): + continue + k, _, v = line.partition("=") + if k.strip(): + kv[k.strip()] = v.strip() + except OSError: + continue + try: + nt = int(kv["FA_NUM_TILES"]) + s1 = int(kv["FA_S1_TILE"]) + 
def lib_path_for(num_tiles: int) -> str:
    """Return the path of the prebuilt shared library for ``num_tiles``.

    The 16-tile variant is the builder default and is stored untagged as
    ``fa.so``; every other variant is stored as ``fa_<num_tiles>.so``. Both
    live in ``ARTIFACT_DIR``. The file is not checked for existence here.
    """
    filename = "fa.so" if num_tiles == 16 else f"fa_{num_tiles}.so"
    return os.path.join(ARTIFACT_DIR, filename)
def fa_reference(q, k, v):
    """Compute plain (unfused) attention in fp32 as the golden reference.

    Evaluates ``softmax(q @ k.T / sqrt(d)) @ v`` where ``d = q.shape[1]``.
    All three inputs are up-cast with ``.float()`` before any arithmetic, and
    the result is returned as an fp32 tensor.
    """
    q32 = q.float()
    k32 = k.float()
    v32 = v.float()
    inv_sqrt_d = 1.0 / math.sqrt(q.shape[1])
    logits = torch.matmul(q32, k32.T) * inv_sqrt_d
    weights = torch.softmax(logits, dim=-1)
    return torch.matmul(weights, v32).float()
blockDim={block_dim}): PASSED " + f"(atol={ATOL}, rtol={RTOL}) GM/blk={GM_ELEMS_PER_BLOCK} fp32" + ) + + +def benchmark_flash(lib, device, num_tiles, warmup=10, iters=100): + """Benchmark a single (length-tagged) .so. Returns dict of metrics.""" + torch.manual_seed(0) + Q_ROWS = fb.Q_ROWS + HEAD = fb.HEAD + S1_TILE = fb.S1_TILE + s1_total = S1_TILE * num_tiles + + block_dim = get_block_dim() + slot_elems = get_slot_elems(block_dim) + + q = torch.randn((Q_ROWS, HEAD), dtype=torch.float16, device=device) + k = torch.randn((s1_total, HEAD), dtype=torch.float16, device=device) + v = torch.randn((s1_total, HEAD), dtype=torch.float16, device=device) + + gm_slot = torch.zeros((slot_elems,), dtype=torch.float32, device=device) + o = torch.zeros((Q_ROWS, HEAD), dtype=torch.float32, device=device) + stream_ptr = torch.npu.current_stream()._as_parameter_ + + def run_kernel(): + lib.call_kernel( + block_dim, + stream_ptr, + torch_to_ctypes(gm_slot), + torch_to_ctypes(q), + torch_to_ctypes(k), + torch_to_ctypes(v), + torch_to_ctypes(o), + ) + + def run_reference(): + fused_attention(q, k, v) + + kernel_us = do_bench( + run_kernel, + warmup_iters=warmup, + benchmark_iters=iters, + unit="us", + flush_cache=False, + ) + ref_us = do_bench( + run_reference, + warmup_iters=warmup, + benchmark_iters=iters, + unit="us", + flush_cache=False, + ) + + # One untimed correctness check per length: assert against the fp32 + # reference so silent miscompiles fail loudly instead of just showing + # a large max|err| in the summary table. 
def plot_benchmark_results(results, out_png=None):
    """Write a TFLOP/s-versus-sequence-length chart for the benchmark rows.

    ``results`` is the list of per-length metric dicts; only the ``seq_len``,
    ``kernel_tflops`` and ``ref_tflops`` keys are read. Nothing happens when
    it is empty. The image goes to ``out_png`` if given, otherwise to the
    ``FA_BENCH_PLOT_PATH`` environment variable, otherwise to
    ``DEFAULT_PLOT_PATH``. If matplotlib cannot be imported a warning is
    printed and the plot is skipped.
    """
    if not results:
        return

    if not out_png:
        out_png = os.environ.get("FA_BENCH_PLOT_PATH", DEFAULT_PLOT_PATH)

    # matplotlib is an optional dependency; the non-interactive Agg backend
    # is selected before pyplot is imported.
    try:
        import matplotlib

        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        print("Warning: matplotlib is not installed; skipping plot generation.")
        return

    # Use the first style sheet that loads; names that raise OSError are
    # skipped, and if none loads the current style is kept.
    for candidate in ("seaborn-v0_8-whitegrid", "seaborn-whitegrid"):
        try:
            plt.style.use(candidate)
        except OSError:
            continue
        break

    def column(key):
        return [row[key] for row in results]

    seq_lens = column("seq_len")
    ours = column("kernel_tflops")
    baseline = column("ref_tflops")

    figure, axis = plt.subplots(figsize=(7, 5))
    figure.patch.set_facecolor("white")

    axis.plot(seq_lens, ours, "o-", label="PTO flash attention")
    axis.plot(seq_lens, baseline, "s-", label="torch_npu fused attention")
    axis.set_title("Throughput")
    axis.set_xlabel("S1 sequence length")
    axis.set_ylabel("TFLOP/s")
    axis.legend()
    # Benchmark lengths are powers of two by default, hence the log2 x axis
    # with one tick per measured length.
    axis.set_xscale("log", base=2)
    axis.set_xticks(seq_lens)
    axis.set_xticklabels(list(map(str, seq_lens)), rotation=30)

    title = (
        f"Flash Attention Benchmark: Q={fb.Q_ROWS}, H={fb.HEAD}, "
        f"S1_TILE={fb.S1_TILE}"
    )
    figure.suptitle(title)
    figure.tight_layout()

    parent_dir = os.path.dirname(out_png)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    figure.savefig(out_png, dpi=180)
    plt.close(figure)
    print(f"Saved benchmark plot: {out_png}")
a/ptodsl/api/pto_general.py b/ptodsl/api/pto_general.py index a7f25593..a897f2d1 100644 --- a/ptodsl/api/pto_general.py +++ b/ptodsl/api/pto_general.py @@ -217,9 +217,9 @@ def initialize_l2g2l_pipe( slot_size, slot_num, _unwrap(gm_addr), - _unwrap(local_addr), local_slot_num=local_slot_num, flag_base=flag_base, + local_addr=_unwrap(local_addr), peer_local_addr=( _unwrap(peer_local_addr) if peer_local_addr is not None else None ),