119 changes: 119 additions & 0 deletions agent_temp/a5_optim_plan.md
@@ -0,0 +1,119 @@
---
name: megagdn a5 optimization
overview: Optimize the already-compiling MegaGDN A5 PTO kernels for real Ascend950PR performance, focusing on replacing GM Cube-Vector handoffs in `chunk_h` and `chunk_o` with direct A5 paths and benchmarking each attempted variant against the saved A2 PTO baseline.
todos:
  - id: profile-current-hotspots
    content: Map current `chunk_h` and `chunk_o` GM workspace handoffs and estimate bytes/FLOPs per handoff.
    status: completed
  - id: explore-a5-manual-patterns
    content: "Explore all relevant A5 manual optimization patterns: flash attention direct C/V modes, gemm_ar L1 reuse, and SIMT/D-cache guidance."
    status: completed
  - id: analyze-buffer-core-util
    content: Analyze A5 UB/L1/L0C/core utilization for current `chunk_h` and `chunk_o` and identify safe double-buffer opportunities.
    status: completed
  - id: prototype-chunk-h-c2v
    content: Prototype direct `L0C -> UB` replacement for `chunk_h` Cube-to-Vec WS/KV handoff and validate correctness.
    status: completed
  - id: prototype-chunk-h-v2c
    content: Prototype direct `UB -> L1` replacement for `chunk_h` Vec-to-Cube K/S handoff and validate correctness.
    status: completed
  - id: prototype-chunk-o-c2v
    content: Prototype direct `L0C -> UB` replacement for `chunk_o` QK/QS/QKV handoffs and validate correctness.
    status: completed
  - id: prototype-chunk-o-v2c
    content: Prototype direct `UB -> L1` replacement for `chunk_o` QK_gated handoff and validate correctness.
    status: completed
  - id: optimize-kkt
    content: Explore A5-specific optimization candidates for `scaled_dot_kkt`, including direct C/V paths and L1/UB reuse where applicable.
    status: completed
  - id: optimize-wy-fast
    content: Explore A5-specific optimization candidates for `wy_fast`, including direct handoff, L1 panel reuse, and H=64 timeout mitigation.
    status: completed
  - id: explore-simt-memory-bound
    content: Evaluate whether SIMT/D-cache techniques are applicable to memory-bound scalar or elementwise sections after direct C/V experiments.
    status: completed
  - id: benchmark-variants
    content: Benchmark all passing variants on real `npu:0` with reduced and large shapes, avoiding unsafe H=64 paths until stable.
    status: completed
  - id: select-best-report
    content: Keep the best-performing correct variants and write the final A5 optimization report and lessons learned.
    status: completed
isProject: false
---

# MegaGDN A5 Performance Optimization Plan

## Goal
- Improve A5 PTO kernel performance beyond the current mechanical port in [`/home/jzhuang/megagdn-pto/kernels/pto_a5`](/home/jzhuang/megagdn-pto/kernels/pto_a5).
- Prioritize `chunk_h` and `chunk_o`, then check `scaled_dot_kkt` and `wy_fast` if time permits.
- Leave PTO `tri_inverse` / `solve_tril` for future work; keep the current torch fallback for correctness only.
- Target: approach or exceed `3x` speedup versus saved A2 PTO timings in [`/home/jzhuang/megagdn-pto/outputs/data/kernel_bench.json`](/home/jzhuang/megagdn-pto/outputs/data/kernel_bench.json), while preserving numerical correctness on real `npu:0`.

## Baseline To Preserve
- Current A5 correctness port lives in [`/home/jzhuang/megagdn-pto/kernels/pto_a5`](/home/jzhuang/megagdn-pto/kernels/pto_a5).
- Current A5 benchmark/comparison artifacts:
  - [`outputs/data/kernel_bench_a5.json`](/home/jzhuang/megagdn-pto/outputs/data/kernel_bench_a5.json)
  - [`outputs/data/kernel_bench_a5_comparison.md`](/home/jzhuang/megagdn-pto/outputs/data/kernel_bench_a5_comparison.md)
- Known limitation: H=64 large-shape `wy_fast` timed out. Avoid unsafe reruns until smaller variants validate.

## Optimization Candidates
- Direct C2V in `chunk_h`:
  - Replace Cube `TSTORE WS = W @ S -> GM workspace` plus Vec `TLOAD WS <- GM` with A5 `L0C -> UB` using `TMOV` / `copy_matrix_cc_to_ub`.
  - Use explicit ready/free sync so Cube does not overwrite the UB handoff before Vec consumes it.
- Direct V2C in `chunk_h`:
  - Replace Vec `TSTORE K_scaled/S -> GM workspace` plus Cube `TLOAD -> L1` with `TINSERT` / `copy_ubuf_to_cbuf` into L1.
  - Convert Vec ND tiles to NZ before Cube consumption.
  - Use one-slot conservative sync first; only then try double-buffering.
- Direct C2V in `chunk_o`:
  - Replace Cube QK/QS/QKV GM workspace stores with direct `L0C -> UB` where Vec immediately gates/combines outputs.
  - Start with QKV or QS only if full replacement is too risky.
- Direct V2C in `chunk_o`:
  - Replace Vec QK_gated GM workspace store with direct `UB -> L1` via `TINSERT` before Cube GEMM3.
  - Apply the verified `stream_v2c` and `add_matmul_v2c` ownership pattern from [`/home/jzhuang/pto-kernels-fork/examples/jit_cpp/cv_sync_demo_a5`](/home/jzhuang/pto-kernels-fork/examples/jit_cpp/cv_sync_demo_a5).
- A5 buffer/core utilization:
  - Re-check UB/L1/L0C footprint against DAV_3510 capacities from [`/home/jzhuang/cannbot-skills/ops/npu-arch/SKILL.md`](/home/jzhuang/cannbot-skills/ops/npu-arch/SKILL.md): UB ~248KB, L0C 256KB, Cube cores 28/32 depending on SKU.
  - Increase local double-buffering only where capacity allows.
  - Avoid GM scratch when direct local buffers fit.
- Advanced patterns to inspect and reuse:
  - Manual A5 flash attention direct C/V modes in [`/home/jzhuang/pto-isa/kernels/manual/a5/flash_atten`](/home/jzhuang/pto-isa/kernels/manual/a5/flash_atten).
  - A5 matmul/L1 reuse patterns in [`/home/jzhuang/pto-isa/kernels/manual/a5/gemm_ar`](/home/jzhuang/pto-isa/kernels/manual/a5/gemm_ar).
  - SIMT/D-cache ideas only for memory-bound scalar/elementwise sections if direct C/V is insufficient.
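
The capacity re-check above can be roughed out with simple arithmetic. The sketch below assumes the `128 x 128` tile shape implied by `C_PTO = 128` and `D = 128` in the benchmark script, treats the quoted UB and L0C sizes as exact KiB budgets, and uses hypothetical helper names (`tile_bytes`, `max_slots`). Real kernels hold additional working tiles, so these counts are upper bounds, not a layout.

```python
# Rough footprint arithmetic for direct C/V handoff tiles (illustrative only).
# Assumptions: 128 x 128 tiles; UB ~248 KiB and L0C 256 KiB as quoted in this plan.
KIB = 1024
UB_BYTES = 248 * KIB
L0C_BYTES = 256 * KIB


def tile_bytes(rows: int, cols: int, elem_bytes: int) -> int:
    """Bytes occupied by one dense rows x cols tile."""
    return rows * cols * elem_bytes


def max_slots(capacity_bytes: int, per_slot_bytes: int) -> int:
    """How many whole handoff slots fit in a buffer of the given capacity."""
    return capacity_bytes // per_slot_bytes


fp32_tile = tile_bytes(128, 128, 4)  # accumulator-precision tile
fp16_tile = tile_bytes(128, 128, 2)  # half-precision operand tile

print("fp32 tile KiB:", fp32_tile // KIB)
print("fp32 slots in UB:", max_slots(UB_BYTES, fp32_tile))
print("fp16 slots in UB:", max_slots(UB_BYTES, fp16_tile))
print("fp32 slots in L0C:", max_slots(L0C_BYTES, fp32_tile))
```

Under these assumptions, two fp32 handoff slots (double-buffering) would take 128 KiB of UB and leave about 120 KiB for Vec working tiles, which is one reason to validate a single slot first. Whether the full UB is available to one Vec core should be confirmed against the architecture notes.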

## Execution Strategy
- Work in small variants, one kernel at a time:
  - Create/modify one candidate path.
  - Run quick correctness for `H=16`.
  - Run quick correctness for `H=16,32,48,64` if safe.
  - Benchmark only that stage at reduced iterations.
  - Keep or revert based on correctness and measured speed.
- Prefer preserving working A5 code with compile-time switches for experiments, e.g. `GDN_A5_DIRECT_CV_CHUNK_H`, until the best variant is selected.
- Avoid full H=64 large-shape benchmarks until H=16/32/48 are stable.
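
The keep-or-revert step can be written down as a small decision rule. This is a minimal sketch, not part of the repository: `VariantResult`, `keep_variant`, and the 2% `min_gain` threshold are all hypothetical and chosen only for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class VariantResult:
    name: str                 # hypothetical label, e.g. a compile-time switch name
    correct: bool             # did the quick correctness test pass?
    time_ms: Optional[float]  # measured stage time; None if not benchmarked


def keep_variant(candidate: VariantResult, baseline: VariantResult,
                 min_gain: float = 1.02) -> bool:
    """Keep a candidate only if it is correct and measurably faster."""
    if not candidate.correct:
        return False  # correctness always wins over speed
    if candidate.time_ms is None or candidate.time_ms <= 0:
        return False  # no usable measurement, so nothing to justify keeping it
    if baseline.time_ms is None:
        return True   # nothing to compare against yet
    return baseline.time_ms / candidate.time_ms >= min_gain
```

A margin above 1.0 is used so that a variant is not kept on the strength of run-to-run jitter alone; the right margin depends on how noisy the stage timings are.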

## Validation Commands
- Environment:
  - `conda activate torch_npu_dev`
  - `source /usr/local/Ascend/cann-9.0.0/set_env.sh`
  - `export GDN_NPU_DEVICE=npu:0`
  - `export MEGAGDN_PTO_ARCH=a5`
- Quick correctness per stage:
  - `python3 tests/test_single_kernels.py --device npu:0 --quick --H-list 16 --stage chunk_h`
  - `python3 tests/test_single_kernels.py --device npu:0 --quick --H-list 16 --stage chunk_o`
- Broader correctness after a variant passes:
  - `python3 tests/test_single_kernels.py --device npu:0 --quick --H-list 16,32,48,64 --stage chunk_h,chunk_o`
- Stage benchmark examples:
  - `GDN_BENCH_WARMUP=1 GDN_BENCH_ITERS=3 python3 benchmarks/kernel/bench_gdn_kernels.py --device npu:0 --n-seq 16 --l-seg 16384 --H-list 16,32,48 --stage chunk_h,chunk_o --output-json outputs/data/kernel_bench_a5_opt.json`
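
Because a hung kernel can stall a whole session, it may help to launch these commands from a wrapper with a wall-clock limit. The sketch below is an assumption-laden illustration: `quick_test_cmd` and `run_with_timeout` are hypothetical helpers that only rebuild the command lines listed above and use the standard `subprocess` timeout. Killing the host process does not by itself recover a stuck AICore.

```python
import subprocess
import sys


def quick_test_cmd(stage: str, heads: str = "16", device: str = "npu:0") -> list[str]:
    """Build the quick single-stage correctness command listed above."""
    return ["python3", "tests/test_single_kernels.py", "--device", device,
            "--quick", "--H-list", heads, "--stage", stage]


def run_with_timeout(cmd: list[str], timeout_s: float) -> str:
    """Run cmd once and return 'pass', 'fail', or 'timeout' (no retries)."""
    try:
        proc = subprocess.run(cmd, timeout=timeout_s)
    except subprocess.TimeoutExpired:
        # subprocess.run kills the child before re-raising the timeout.
        return "timeout"
    return "pass" if proc.returncode == 0 else "fail"


# Harmless self-check that does not touch the NPU: a trivially passing command.
print(run_with_timeout([sys.executable, "-c", "pass"], timeout_s=30))
```

In practice one would call `run_with_timeout(quick_test_cmd("chunk_h"), timeout_s=...)` with a limit chosen from a known-good run, and treat a `"timeout"` result as a reason to stop that path rather than rerun it.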

## Reporting
- Produce a final optimization report, for example [`outputs/data/kernel_bench_a5_opt_report.md`](/home/jzhuang/megagdn-pto/outputs/data/kernel_bench_a5_opt_report.md), containing:
  - Variant list and whether each passed correctness.
  - Best timings by stage and H.
  - Speedup vs A2 baseline.
  - Which direct C/V exchanges were successfully eliminated from GM.
  - Any failed attempts and why.
- If any kernel reaches or exceeds `3x` vs A2, add a short “A5 optimization practices learned” section to [`kernels/pto_a5/PORT_STATUS.md`](/home/jzhuang/megagdn-pto/kernels/pto_a5/PORT_STATUS.md) or a new notes file.
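
The speedup column of such a report is a ratio of A2 time to A5 time per stage and H. The following sketch assumes the timings have already been loaded into plain dictionaries keyed by `(stage, H)`; the actual layout of `kernel_bench.json` is not shown in this plan, so `speedup_rows` and its input format are hypothetical.

```python
def speedup_rows(a2_ms: dict, a5_ms: dict, target: float = 3.0) -> list[str]:
    """Markdown table rows of A2/A5 speedup for keys present in both inputs."""
    rows = ["| stage | H | A2 ms | A5 ms | speedup | meets target |",
            "|---|---|---|---|---|---|"]
    # Only compare entries measured on both sides, e.g. skip an omitted H=64 row.
    for key in sorted(a2_ms.keys() & a5_ms.keys()):
        stage, heads = key
        ratio = a2_ms[key] / a5_ms[key]
        flag = "yes" if ratio >= target else "no"
        rows.append(f"| {stage} | {heads} | {a2_ms[key]:.3f} | "
                    f"{a5_ms[key]:.3f} | {ratio:.2f}x | {flag} |")
    return rows
```

Restricting the table to keys present in both inputs keeps missing or timed-out configurations out of the comparison instead of reporting them as zero or infinite speedups.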

## Guardrails
- Do not work on PTO `tri_inverse` in this optimization pass.
- Do not count torch fallback solve time as PTO performance.
- Do not claim speedups from noisy tiny-stage timings like `chunk_cumsum` unless measurement is robust.
- If a variant triggers AICore timeout, stop that path, verify device health, and document it rather than repeatedly rerunning unsafe cases.
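
One way to make the "robust measurement" guardrail checkable is to require a minimum duration and a bounded spread across repeated timings before a speedup is reported. This is a sketch only: `is_robust` and both thresholds are hypothetical and would need tuning against real `npu:0` runs.

```python
import statistics


def is_robust(samples_ms: list[float], min_median_ms: float = 0.05,
              max_rel_spread: float = 0.10) -> bool:
    """Heuristic gate for quoting a timing: enough samples, not tiny, not noisy."""
    if len(samples_ms) < 3:
        return False  # too few repetitions to judge spread
    median = statistics.median(samples_ms)
    if median < min_median_ms:
        return False  # tiny stages are dominated by launch and timer overhead
    spread = (max(samples_ms) - min(samples_ms)) / median
    return spread <= max_rel_spread
```

With `GDN_BENCH_ITERS=3` the check has only three samples to work with, so a stage that fails it is better re-measured with more iterations than quoted in the report.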
45 changes: 41 additions & 4 deletions benchmarks/kernel/bench_gdn_kernels.py
@@ -77,6 +77,7 @@

C_PTO = 128
D = 128
PTO_ONLY = True


# ---------------------------------------------------------------------------
@@ -85,6 +86,8 @@

def _bench_npu(fn, warmup: int = 5, iters: int = 15) -> float:
    """Time an NPU function using Event pairs (ms)."""
    # When set, these environment variables override the caller-supplied values.
    warmup = int(os.getenv("GDN_BENCH_WARMUP", str(warmup)))
    iters = int(os.getenv("GDN_BENCH_ITERS", str(iters)))
    starts = [torch.npu.Event(enable_timing=True) for _ in range(iters)]
    ends = [torch.npu.Event(enable_timing=True) for _ in range(iters)]
    cache = torch.empty(256 * 1024 * 1024, dtype=torch.int8).npu()
@@ -133,6 +136,8 @@ def _ratio(ms_triton: float | None, ms_pto: float) -> str:
# ---------------------------------------------------------------------------

def _try_triton_cumsum(cu_seqlens, BT, dev, T, H) -> float | None:
    if PTO_ONLY:
        return None
    try:
        from fla_vendor.cumsum import chunk_local_cumsum
        from fla_vendor.utils import prepare_chunk_indices
@@ -148,6 +153,8 @@ def _try_triton_cumsum(cu_seqlens, BT, dev, T, H) -> float | None:


def _try_triton_kkt(cu_seqlens, BT, dev, T, H, HG) -> float | None:
    if PTO_ONLY:
        return None
    try:
        from fla_vendor.chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
        from fla_vendor.utils import prepare_chunk_indices
@@ -168,6 +175,8 @@ def _try_triton_kkt(cu_seqlens, BT, dev, T, H, HG) -> float | None:


def _try_triton_solve_tril(cu_seqlens, BT, dev, T, H) -> float | None:
    if PTO_ONLY:
        return None
    if BT > 64:
        print(f" [Triton solve_tril BT={BT}: not supported by Triton (max BT=64)]")
        return None
@@ -191,6 +200,8 @@ def _try_triton_solve_tril(cu_seqlens, BT, dev, T, H) -> float | None:


def _try_triton_chunk_h(cu_seqlens, BT, dev, T, H, HG) -> float | None:
    if PTO_ONLY:
        return None
    # H=64 with BT=64 triggers an aicore exception on this NPU, corrupting device state.
    if H >= 64 and BT <= 64:
        print(f" [Triton chunk_h BT={BT} H={H}: known aicore failure, skip]")
@@ -218,6 +229,8 @@ def _try_triton_chunk_h(cu_seqlens, BT, dev, T, H, HG) -> float | None:


def _try_triton_wy_fast(cu_seqlens, BT, dev, T, H, HG) -> float | None:
    if PTO_ONLY:
        return None
    try:
        from fla_vendor.wy_fast import recompute_w_u_fwd
        from fla_vendor.utils import prepare_chunk_indices
@@ -239,6 +252,8 @@ def _try_triton_wy_fast(cu_seqlens, BT, dev, T, H, HG) -> float | None:


def _try_triton_chunk_o(cu_seqlens, BT, dev, T, H, HG) -> float | None:
    if PTO_ONLY:
        return None
    # H=64 with BT=64 is a known aicore failure; skip to protect NPU state.
    if H >= 64 and BT <= 64:
        print(f" [Triton chunk_o BT={BT} H={H}: known aicore failure, skip]")
@@ -335,10 +350,24 @@ def bench_solve_tril(H, T, cu_seqlens, dev, tri_inv):
    Reduce ``L_seg`` (e.g. 8192 for H≤16 full parity; 4096 for H≤32) to benchmark
    the Triton reference without grid overflow.
    """
    A = torch.zeros(1, T, H, C_PTO, device=dev, dtype=torch.float16).tril(-1)
    cu32 = cu_seqlens.to(torch.int32)
    # On A5 the PTO tri_inverse kernel is not used; time the torch fallback instead.
    if os.environ.get("MEGAGDN_PTO_ARCH", "").lower() in {"a5", "dav3510", "dav_3510", "ascend950"}:
        out = torch.empty_like(A)

        def run_solve_fallback():
            solve_tril(A, cu32, C_PTO, H, tri_inv, out_fp16=out)

        run_solve_fallback()
        torch.npu.synchronize()
        ms_pto = _bench_npu(run_solve_fallback, warmup=1, iters=3)
        ms_t64 = _try_triton_solve_tril(cu_seqlens, 64, dev, T, H)
        ms_t128 = None
        _print_stage("solve_tril_torch_fallback", ms_pto, ms_t64, ms_t128)
        return ms_pto, ms_t64, ms_t128

    _ = tri_inv  # preload contract shared with staged/mega callers

    workspace_fp32 = torch.zeros_like(A, dtype=torch.float32)
    batch = int(cu32.numel()) - 1
    tc = total_chunks(batch, T, C_PTO, cu32)
@@ -367,7 +396,8 @@ def run_tri_inverse_kernel():
    ms_pto = _bench_npu(run_tri_inverse_kernel)
    ms_t64 = _try_triton_solve_tril(cu_seqlens, 64, dev, T, H)
    ms_t128 = None  # BT=128 not supported by Triton (max BT=64)
    if not PTO_ONLY:
        print(f" [Triton solve_tril BT=128: not supported (max BT=64)]")
    _print_stage("solve_tril", ms_pto, ms_t64, ms_t128)
    return ms_pto, ms_t64, ms_t128

@@ -541,6 +571,7 @@ def run_staged():
# ---------------------------------------------------------------------------

def main() -> None:
    global PTO_ONLY
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--device", default=os.getenv("GDN_NPU_DEVICE", "npu:0"))
    parser.add_argument("--n-seq", type=int, default=None,
@@ -554,9 +585,12 @@ def main() -> None:
                        default="cumsum,kkt,solve_tril,wy_fast,chunk_h,chunk_o",
                        help="Comma-separated stages to benchmark.")
    parser.add_argument("--mega", action="store_true", help="Also benchmark mega-kernel.")
    parser.add_argument("--with-triton-baseline", action="store_true",
                        help="Also try Triton baselines. Off by default for the current A5 environment.")
    parser.add_argument("--output-json", default=None,
                        help="Save results as JSON to this path.")
    args = parser.parse_args()
    PTO_ONLY = not args.with_triton_baseline

    torch.manual_seed(0)
    torch.npu.set_device(args.device)
@@ -575,7 +609,8 @@ def main() -> None:
    heads_list = [int(x) for x in args.H_list.split(",") if x.strip()]
    HG = args.hg

    is_a5 = os.environ.get("MEGAGDN_PTO_ARCH", "").lower() in {"a5", "dav3510", "dav_3510", "ascend950"}
    tri_inv_needed = args.mega or ("solve_tril" in {s.strip() for s in args.stage.split(",")} and not is_a5)
    tri_inv = load_tri_inverse() if tri_inv_needed else None

    print(f"Workload: N_seq={N_seq} L_seg={L_seg} T={T} D={D} C_PTO={C_PTO} BLOCK_DIM={bd}")
@@ -593,6 +628,8 @@ def _save_results() -> None:
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "device": args.device,
            "N_seq": N_seq, "L_seg": L_seg, "D": D, "C_pto": C_PTO,
            # Note: an unset MEGAGDN_PTO_ARCH is recorded as "a5" here, while the
            # is_a5 check in main() treats an unset variable as non-A5.
            "pto_arch": os.environ.get("MEGAGDN_PTO_ARCH", "a5"),
            "pto_only": PTO_ONLY,
            "results": all_results,
        }
        out_path.write_text(json.dumps(meta, indent=2))
73 changes: 73 additions & 0 deletions kernels/pto_a5/PORT_STATUS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
# MegaGDN PTO A5 Port Status

## Summary

The A2-oriented PTO kernels were copied to `kernels/pto_a5` and mechanically
ported to compile under A5 / DAV_3510 with:

```bash
export MEGAGDN_PTO_ARCH=a5
source /usr/local/Ascend/cann-9.0.0/set_env.sh
```

The Python build path now selects `kernels/pto_a5` and compiles with
`--cce-aicore-arch=dav-c310`.
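
The selection logic is not shown in this diff, so the following is only a sketch of how such a switch could look. It assumes the build path accepts the same `MEGAGDN_PTO_ARCH` aliases that the benchmark script checks; `is_a5_arch` and `a5_build_settings` are hypothetical names, and the non-A5 directory and flags are deliberately left out because they do not appear in this document.

```python
import os

# Aliases taken from the check in benchmarks/kernel/bench_gdn_kernels.py.
A5_ALIASES = {"a5", "dav3510", "dav_3510", "ascend950"}


def is_a5_arch(env=None) -> bool:
    """True when MEGAGDN_PTO_ARCH names the A5 target (case-insensitive)."""
    env = os.environ if env is None else env
    return env.get("MEGAGDN_PTO_ARCH", "").lower() in A5_ALIASES


def a5_build_settings(env=None):
    """Return (kernel_dir, extra_flags) for A5, or None when A5 is not selected."""
    if not is_a5_arch(env):
        return None  # non-A5 settings are not described in this document
    return "kernels/pto_a5", ["--cce-aicore-arch=dav-c310"]


print(a5_build_settings({"MEGAGDN_PTO_ARCH": "a5"}))
```

Keeping the alias set in one place would also avoid the duplicated set literal that currently appears twice in the benchmark script.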

## Correctness

Quick single-kernel tests passed on real `npu:0` for `H=16,32,48,64`:

```bash
python3 tests/test_single_kernels.py --device npu:0 --quick --H-list 16,32,48,64
```

Stages covered:

- `chunk_cumsum`
- `scaled_dot_kkt`
- `solve_tril` via A5 torch fallback
- `wy_fast`
- `chunk_h`
- `chunk_o`

## Performance Artifacts

Generated:

- `outputs/data/kernel_bench_a5.json`
- `outputs/data/kernel_bench_a5_comparison.json`
- `outputs/data/kernel_bench_a5_comparison.md`

Command used for the completed large-shape run:

```bash
GDN_BENCH_WARMUP=1 GDN_BENCH_ITERS=3 \
MEGAGDN_PTO_ARCH=a5 \
python3 benchmarks/kernel/bench_gdn_kernels.py \
--device npu:0 \
--n-seq 16 \
--l-seg 16384 \
--H-list 16,32,48,64 \
--stage cumsum,kkt,wy_fast,chunk_h,chunk_o \
--output-json outputs/data/kernel_bench_a5.json
```

The run completed H=16,32,48 and wrote those rows. H=64 timed out during
`wy_fast`, so H=64 is intentionally omitted from the comparison JSON.

## Known Limitations

- `tri_inverse` / PTO `solve_tril` is not fully ported. The A5 copy compiles
  after layout fixes but produces NaNs. `solve_tril` uses a torch fallback for
  A5 correctness only, and this fallback is not counted as a PTO performance
  result.
- `mega_kernel` is not validated on A5 yet because it depends on PTO
  `tri_inverse`.
- The current A5 kernels still mostly use the original GM workspace exchange
  patterns. A deeper optimization pass should replace the remaining Cube-Vector
  GM handoffs in `wy_fast`, `chunk_h`, `chunk_o`, and `mega_kernel` with direct
  A5 `TMOV` / `TINSERT` paths.
- After an H=64 timeout, the NPU runtime reported device reopen failures in a
  later e2e attempt. A runtime/device reset is recommended before additional
  long A5 benchmark runs.
