diff --git a/.agent/skills/translate_cpp2py/SKILL.md b/.agent/skills/translate_cpp2py/SKILL.md new file mode 100644 index 00000000..61f7b8aa --- /dev/null +++ b/.agent/skills/translate_cpp2py/SKILL.md @@ -0,0 +1,176 @@ +--- +name: translate-cpp2py +description: Translate manual PTO-ISA C++ kernels into PTO-DSL Python builders and verification harnesses. Use when converting pto-isa kernel code to ptodsl, generating .pto/.cpp via ptoas, handling manual vs auto sync variants, separating vector vs cube APIs, or adding missing ptodsl API wrappers. +--- + +# Translate PTO-ISA C++ to PTO-DSL + +## Scope + +This skill converts a manually written PTO C++ kernel into: +- a **manual-sync** PTO-DSL Python builder (must mirror source C++ behavior), +- an **auto-sync** PTO-DSL variant (same math/control flow, sync removed), +- generated `.pto` and `.cpp`, +- launcher and runtime correctness test scripts. + +Primary references are under `references/example_translation`. Only consult long compiler/dialect sources when mapping is missing. + +## Required Outputs Per Translation Task + +Produce all of the following unless user asks otherwise: +- Python builder for **manual-sync** kernel. +- Python builder for **auto-sync** kernel. +- Compile scripts: + - manual: `python builder.py > kernel.pto && ptoas kernel.pto -o kernel.cpp` + - auto: `python builder.py > kernel.pto && ptoas --enable-insert-sync kernel.pto -o kernel.cpp` +- `caller.cpp` kernel launcher with correct ABI and launch geometry. +- `run_*.py` load-and-test script to validate numerical correctness. +- `README.md` with minimal usage commands (compile + run + optional bench), following concise style used in `examples/aot/*/README.md`. + +## Non-Negotiable Rules + +1. Input C++ is manual-sync by default. Port to manual-sync Python first. +2. Then create auto-sync variant by removing explicit sync APIs and compiling with `--enable-insert-sync`. +3. Preserve ABI exactly: function name, argument order/types, launch contract. 
+4. Match section type exactly: vector (`__DAV_VEC__`) vs cube (`__DAV_CUBE__`). +5. Prefer compact Python; preserve semantics, not C++ verbosity. +6. If wrapper is missing in `ptodsl/api`, add it instead of forcing awkward translation. +7. First check whether the directory `references/example_translation` is empty or contains too few examples. + If so, ask the user to run `scripts/collect_example_translate.py` to generate the full Python-C++ mapping examples. + + +## Translation Workflow + +1. **Classify kernel** + - Determine section: vector vs cube. + - Determine sync style: manual vs auto (source C++ is manual). + - Identify core partitioning pattern (block/subblock/batch split). + +2. **Rebuild signature + metadata first** + - Define `meta_data()` with scalar/index/pointer/tensor/subtensor/tile types. + - Use `@to_ir_module(meta_data=meta_data)`. + - Keep argument order identical to C++ kernel ABI. + +3. **Port runtime control flow** + - Use `pto.range`, `pto.if_context`, `pto.cond` for runtime logic. + - Keep all tail guards and truncation branches. + +4. **Port data movement + tile math** + - Build tensors via `pto.as_tensor`. + - Create subviews with `pto.slice_view`. + - Allocate tiles with `pto.alloc_tile`. + - Map load/store/compute ops 1:1 (see mapping rules below). + +5. **Handle synchronization** + - Manual variant: keep explicit event/barrier calls. + - Auto variant: remove manual sync calls, keep op order, compile with insert-sync pass. + +6. **Generate and verify round-trip** + - Emit `.pto`, compile to `.cpp`, and sanity-check structural equivalence. + - Build `.so` with `caller.cpp`. + - Run Python test script against reference (`torch` or equivalent). + +## Sync Modes (Must Explain in Every Task) + +- **Manual sync mode** + - Python uses explicit sync APIs in `ptodsl/api/synchronization.py`. + - Typical APIs: `record_event`, `wait_event`, `record_wait_pair`, `barrier`. + - Compile with plain `ptoas` (no `--enable-insert-sync`). 
+ - Use for direct mirroring of manual C++ or for hand-tuned pipelines. + +- **Auto sync mode** + - Remove explicit sync APIs from Python DSL. + - Compile with `ptoas --enable-insert-sync`. + - Compiler inserts hazard-handling synchronization. + - Use for simpler maintainable variant with same algorithmic behavior. + +Rule of thumb: one kernel variant should use one sync strategy only. + +## Vector vs Cube Section/API Boundaries + +- **Vector kernels** + - Use `with pto.vector_section():` + - Lowers to `#if defined(__DAV_VEC__)`. + - Typical ops: elementwise/reduction/vector dataflow (`tile.add/sub/mul/div/...`). + +- **Cube kernels** + - Use `with pto.cube_section():` + - Lowers to `#if defined(__DAV_CUBE__)`. + - Typical ops: matrix engines (`tile.matmul`, `tile.matmul_acc`, `tile.matmul_bias`). + +- **API surface filtering** + - Vector-only example: `tile.add` in `ptodsl/api/tile.py`. + - Cube-only example: `tile.matmul` in `ptodsl/api/tile.py`. + - Keep agent search narrow: choose section first, then look only at relevant API family. + +## Compact Mapping Rules (Python -> C++) + +1. `@to_ir_module` function -> emitted `__global__ AICORE void ...`. +2. `PtrType(dtype)` -> C++ GM pointer arg type. +3. `TensorType/SubTensorType` + `as_tensor/slice_view` -> `GlobalTensor` objects/views. +4. `TileBufType(memory_space=...)` + `alloc_tile` -> tile declarations in corresponding memory space. +5. `pto.get_block_idx/get_block_num/get_subblock_idx/get_subblock_num` -> runtime core/subcore intrinsics. +6. `s.const/s.index_cast/s.ceil_div/s.select/min` -> scalar arithmetic + branch/select expressions. +7. `pto.range(...)` -> runtime loop in IR/C++. +8. Python `range(...)` -> build-time unroll/metaprogramming. +9. `pto.if_context(...)` / `pto.cond(...)` -> runtime conditional branches. +10. Python `if` -> build-time branch while constructing IR. +11. `pto.load` / `pto.store` -> load/store tile movement ops. +12. 
`tile.add/sub/mul/div/relu/exp/...` -> corresponding PTO compute intrinsics. +13. `tile.matmul*` family -> cube matmul intrinsics. +14. Multicore distribution usually maps via: + - vector core id = `block_idx * subblock_num + subblock_idx` (there are twice as many vector cores as cube cores, so `subblock_num` equals 2) + - tiles per core = ceil-div(total tiles, total cores) + - guarded tail processing for final core(s). +15. Dynamic-shape kernels require explicit bound guards before slicing/loading/storing. + +## Runtime Semantics Reminder (Critical) + +PTO-DSL is Python tracing, not AST rewriting: +- Python-native `if/for` executes at build time, similar to C++ compile-time metaprogramming or loop unrolling. +- Only `pto.range`, `pto.if_context`, and `pto.cond` represent runtime control flow in the generated kernel. + +Never translate runtime C++ control logic into Python-native `if/range` by mistake. + +## Missing API Wrapper Protocol + +If required C++ op has no convenient Python wrapper: + +1. Add thin wrapper in the right module: + - tile/instruction ops -> `ptodsl/api/tile.py` + - general tensor/control helpers -> `ptodsl/api/pto_general.py` + - sync helpers -> `ptodsl/api/synchronization.py` +2. Re-export through `ptodsl/api/pto.py` when needed. +3. Keep wrapper minimal: pass through to MLIR Python binding op with light argument normalization. + +## Escalation Path (Only When Mapping Is Missing) + +Work through these steps, in order, inside `references/external_repo`: +1. Clone the `PTOAS` and `pto-isa` repos +2. Check Dialect op definitions: `PTOOps.td` in `PTOAS` repo +3. C++ codegen lowering: `PTOToEmitC.cpp` in `PTOAS` repo +4. ISA semantics: `pto_instr.hpp` in `pto-isa` repo + +If op exists in dialect but not lowered in `PTOToEmitC.cpp`, translation requires PTOAS compiler work (not only DSL wrapper work). 
+In this case, suggest an issue report to PTOAS project (https://github.com/zhangstevenunity/PTOAS) + +## Round-Trip Verification Checklist + +- [ ] Manual-sync Python version created first and compiles with plain `ptoas`. +- [ ] Auto-sync variant created and compiles with `--enable-insert-sync`. +- [ ] Generated C++ keeps ABI/section/loop/tail semantics. +- [ ] Launcher `caller.cpp` matches kernel symbol and launch parameters. +- [ ] Test script loads `.so`, runs multiple shapes (including tail/non-divisible cases), compares against trusted reference. +- [ ] If multicore kernel: test cases include shapes not multiples of core count. +- [ ] `README.md` documents the exact local commands to compile and run verification. + +## Reference Priority + +Use these first: +- `references/example_translation/**` (primary mapping corpus) +- `references/example_translation/fast_hadamard/**` (manual vs auto sync pair) +- `references/example_translation/batch_matmul/**` (cube kernels) +- `examples/aot/elementwise/add_dynamic_multicore/*` (caller/test/build pattern) +- `examples/aot/matmul_optimization_guide/matmul_optim_guide.md` (sync and runtime-control semantics) + +Consult `references/external_repo/**` only for patterns not covered by examples. 
diff --git a/.agent/skills/translate_cpp2py/references/example_translation/.gitignore b/.agent/skills/translate_cpp2py/references/example_translation/.gitignore new file mode 100644 index 00000000..72e8ffc0 --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/example_translation/.gitignore @@ -0,0 +1 @@ +* diff --git a/.agent/skills/translate_cpp2py/references/external_repo/README.md b/.agent/skills/translate_cpp2py/references/external_repo/README.md new file mode 100644 index 00000000..8ac51e5e --- /dev/null +++ b/.agent/skills/translate_cpp2py/references/external_repo/README.md @@ -0,0 +1,21 @@ +This directory holds the 3rd-party repos that are used internally by PTO-DSL: +- https://github.com/zhangstevenunity/PTOAS: implements "ptoas" command line tool, the PTO MLIR dialect and its Python bindings, and the InjectSync pass to insert set_flag/wait_flag for "auto-sync" mode. Important files are: + - `PTOAS/include/PTO/IR/PTOOps.td` defines the MLIR PTO dialect + - `PTOAS/python/pto/dialects/pto.py` has low-level Python wrappers of PTO MLIR python binding (more Pythonic wrappers are in pto-dsl package) + - `PTOAS/lib/PTO/Transforms/PTOToEmitC.cpp` is the compiler pass that converts `*.pto` IR to C++ source code based on PTO-ISA headers. +- https://gitcode.com/cann/pto-isa: header-only library that defines the C++ APIs of PTO-ISA. It is the target API set for the `PTOToEmitC` pass in PTOAS. Important files are: + - `pto-isa/include/pto/common/pto_instr.hpp` the top-level interface + - `pto-isa/include/pto/common/*` common type definitions + - `pto-isa/include/pto/npu/a2a3/*` implementation for current hardware (used in current pto-dsl examples) + - `pto-isa/include/pto/npu/a5/*` implementation for next-generation hardware (not used in current pto-dsl examples) + +This directory is empty by default, and the repos should be cloned on-the-fly when the agent needs to access extra context. 
+ +For difficult task that needs to look into PTOAS and pto-isa repos, the agent or user can clone them by: + +```bash +git clone https://github.com/zhangstevenunity/PTOAS.git +git clone https://gitcode.com/cann/pto-isa.git +``` + +Remind the user to check if the commit id of PTOAS and pto-isa matches the test environment (usually a pre-built docker image), to avoid mismatch between the context and the real execution. diff --git a/.agent/skills/translate_cpp2py/scripts/collect_example_translate.py b/.agent/skills/translate_cpp2py/scripts/collect_example_translate.py new file mode 100644 index 00000000..ee273c70 --- /dev/null +++ b/.agent/skills/translate_cpp2py/scripts/collect_example_translate.py @@ -0,0 +1,283 @@ +#!/usr/bin/env python3 +"""Collect python->pto->cpp translation examples into a reference directory. + +Usage: + python collect_example_translate.py + python collect_example_translate.py --aot-dir /path/to/examples/aot --out-dir /tmp/example_translation +""" +import json +import argparse +import os +import shutil +import subprocess +from pathlib import Path + + +def unique_dir(base: Path) -> Path: + if not base.exists(): + return base + idx = 2 + while True: + candidate = Path(f"{base}_{idx}") + if not candidate.exists(): + return candidate + idx += 1 + + +REQUIRED_FIELDS = { + "example_dir", + "compile_script", + "py_source", + "py_command", + "ptoas_command", + "pto_file", + "cpp_file", +} +OPTIONAL_FIELDS = {"dependency"} + + +def load_example_list(config_path: Path) -> list[dict[str, object]]: + if not config_path.exists(): + raise FileNotFoundError(f"example config not found: {config_path}") + raw = json.loads(config_path.read_text(encoding="utf-8")) + if not isinstance(raw, list): + raise ValueError("example config root must be a list") + + examples: list[dict[str, object]] = [] + for idx, item in enumerate(raw): + if not isinstance(item, dict): + raise ValueError(f"entry #{idx} must be an object") + missing = REQUIRED_FIELDS - set(item.keys()) + 
if missing: + raise ValueError(f"entry #{idx} missing fields: {sorted(missing)}") + unknown = set(item.keys()) - REQUIRED_FIELDS - OPTIONAL_FIELDS + if unknown: + raise ValueError(f"entry #{idx} has unknown fields: {sorted(unknown)}") + + normalized: dict[str, str | list[str]] = {} + for key in REQUIRED_FIELDS: + value = item[key] + if not isinstance(value, str) or not value.strip(): + raise ValueError( + f"entry #{idx} field '{key}' must be a non-empty string" + ) + normalized[key] = value + + dependency = item.get("dependency", []) + if not isinstance(dependency, list): + raise ValueError( + f"entry #{idx} field 'dependency' must be a list of strings" + ) + dep_list: list[str] = [] + for dep_idx, dep in enumerate(dependency): + if not isinstance(dep, str) or not dep.strip(): + raise ValueError( + f"entry #{idx} dependency[{dep_idx}] must be a non-empty string" + ) + dep_list.append(dep) + normalized["dependency"] = dep_list + examples.append(normalized) + return examples + + +def parse_args() -> argparse.Namespace: + script_dir = Path(__file__).resolve().parent + default_repo_root = (script_dir / "../../../..").resolve() + parser = argparse.ArgumentParser( + description="Collect python->pto->cpp translation examples." 
+ ) + parser.add_argument( + "--repo-root", + type=Path, + default=default_repo_root, + help="Repository root path (default: script_dir/../../../..).", + ) + parser.add_argument( + "--aot-dir", + type=Path, + default=default_repo_root / "examples/aot", + help="AOT examples directory (default: /examples/aot).", + ) + parser.add_argument( + "--out-dir", + type=Path, + default=(script_dir / "../references/example_translation").resolve(), + help="Output directory (default: script_dir/../references/example_translation).", + ) + parser.add_argument( + "--example-config", + type=Path, + default=script_dir / "example_list.json", + help="Example list json path (default: script_dir/example_list.json).", + ) + return parser.parse_args() + + +def main() -> int: + args = parse_args() + repo_root = args.repo_root.resolve() + aot_dir = args.aot_dir.resolve() + out_dir = args.out_dir.resolve() + example_config = args.example_config.resolve() + example_list = load_example_list(example_config) + + if not aot_dir.is_dir(): + raise FileNotFoundError(f"AOT examples directory not found: {aot_dir}") + + if out_dir.exists(): + shutil.rmtree(out_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + copied = 0 + failed = 0 + found = len(example_list) + results: list[dict[str, str]] = [] + + def display_path(path: Path) -> str: + try: + return str(path.relative_to(repo_root)) + except ValueError: + # In CI we may intentionally write outside repo root (e.g. /tmp). 
+ return str(path) + + for idx, example in enumerate(example_list, start=1): + rel_dir = Path(str(example["example_dir"])) + example_dir = aot_dir / rel_dir + py_rel = Path(str(example["py_source"])) + py_source = example_dir / py_rel + py_cmd = str(example["py_command"]) + ptoas_cmd = str(example["ptoas_command"]) + example_name = f"{example['example_dir']}:{example['pto_file']}" + progress_name = py_rel.stem + dependencies = example.get("dependency", []) + print(f"[{idx}/{found}] collecting {rel_dir}/{progress_name}") + + if not py_source.exists(): + failed += 1 + results.append( + { + "name": example_name, + "status": "FAIL", + "reason": f"python source does not exist: {py_source}", + } + ) + continue + + dst = unique_dir(out_dir / rel_dir / Path(str(example["pto_file"])).stem) + dst.mkdir(parents=True, exist_ok=True) + + py_dst = dst / py_rel + py_dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(py_source, py_dst) + dep_copy_failed = False + for dep in dependencies: + dep_src = example_dir / dep + if not dep_src.exists(): + failed += 1 + results.append( + { + "name": example_name, + "status": "FAIL", + "reason": f"dependency does not exist: {dep_src}", + } + ) + dep_copy_failed = True + break + dep_dst = dst / dep + dep_dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(dep_src, dep_dst) + if dep_copy_failed: + continue + + run_env = os.environ.copy() + + py_run = subprocess.run( + py_cmd, + shell=True, + cwd=dst, + env=run_env, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + if py_run.returncode != 0: + failed += 1 + output = (py_run.stdout or "").strip() + results.append( + { + "name": example_name, + "status": "FAIL", + "reason": f"python command failed: {py_cmd}" + + (f" | {output}" if output else ""), + } + ) + continue + + ptoas_run = subprocess.run( + ptoas_cmd, + shell=True, + cwd=dst, + env=run_env, + executable="/bin/bash", + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + 
text=True, + ) + if ptoas_run.returncode != 0: + failed += 1 + output = (ptoas_run.stdout or "").strip() + results.append( + { + "name": example_name, + "status": "FAIL", + "reason": f"ptoas command failed: {ptoas_cmd}" + + (f" | {output}" if output else ""), + } + ) + continue + + pto_dst = dst / str(example["pto_file"]) + cpp_dst = dst / str(example["cpp_file"]) + if not (pto_dst.exists() and cpp_dst.exists()): + failed += 1 + results.append( + { + "name": example_name, + "status": "FAIL", + "reason": ( + "expected outputs missing after compile: " + f"{example['pto_file']}, {example['cpp_file']}" + ), + } + ) + continue + + commands = [ + "#!/usr/bin/env bash", + "set -e", + py_cmd, + ptoas_cmd, + "", + ] + (dst / "compile.sh").write_text("\n".join(commands), encoding="utf-8") + + copied += 1 + results.append( + { + "name": example_name, + "status": "OK", + "reason": f"collected to {display_path(dst)}", + } + ) + + print(f"Discovered {found} python->pto candidates under {aot_dir}") + for item in results: + print(f"[{item['status']}] {item['name']} - {item['reason']}") + print(f"Collected {copied} translation examples into {out_dir}") + print(f"Failed to collect {failed} examples") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/.agent/skills/translate_cpp2py/scripts/example_list.json b/.agent/skills/translate_cpp2py/scripts/example_list.json new file mode 100644 index 00000000..453406ff --- /dev/null +++ b/.agent/skills/translate_cpp2py/scripts/example_list.json @@ -0,0 +1,168 @@ +[ + { + "example_dir": "activations/geglu_dynamic_multicore", + "compile_script": "compile.sh", + "py_source": "geglu_builder.py", + "py_command": "python ./geglu_builder.py > ./geglu.pto", + "ptoas_command": "ptoas --enable-insert-sync ./geglu.pto -o ./geglu.cpp", + "pto_file": "geglu.pto", + "cpp_file": "geglu.cpp" + }, + { + "example_dir": "activations/relu_dynamic_multicore", + "compile_script": "compile.sh", + "py_source": "relu_builder.py", + 
"py_command": "python relu_builder.py > ./relu.pto", + "ptoas_command": "ptoas --enable-insert-sync ./relu.pto > generated_relu.cpp", + "pto_file": "relu.pto", + "cpp_file": "generated_relu.cpp" + }, + { + "example_dir": "batch_matmul/matmul_dynbatch_multicore", + "compile_script": "compile.sh", + "py_source": "matmul_builder.py", + "py_command": "python ./matmul_builder.py > matmul.pto", + "ptoas_command": "ptoas matmul.pto -o matmul.cpp", + "pto_file": "matmul.pto", + "cpp_file": "matmul.cpp" + }, + { + "example_dir": "batch_matmul/matmul_dynbatch_multicore_opt", + "compile_script": "compile.sh", + "py_source": "matmul_builder.py", + "py_command": "python ./matmul_builder.py > matmul.pto", + "ptoas_command": "ptoas matmul.pto -o matmul.cpp", + "pto_file": "matmul.pto", + "cpp_file": "matmul.cpp" + }, + { + "example_dir": "elementwise/add_dynamic_multicore", + "compile_script": "compile.sh", + "py_source": "add_builder.py", + "py_command": "python ./add_builder.py > ./add.pto", + "ptoas_command": "ptoas --enable-insert-sync ./add.pto -o ./add.cpp", + "pto_file": "add.pto", + "cpp_file": "add.cpp" + }, + { + "example_dir": "elementwise/add_dynamic_multicore", + "compile_script": "compile_double.sh", + "py_source": "add_double_builder.py", + "py_command": "python ./add_double_builder.py > ./add_double.pto", + "ptoas_command": "ptoas --enable-insert-sync ./add_double.pto -o ./add_double.cpp", + "pto_file": "add_double.pto", + "cpp_file": "add_double.cpp" + }, + { + "example_dir": "fast_hadamard", + "compile_script": "compile.sh", + "py_source": "hadamard_builder.py", + "py_command": "python ./hadamard_builder.py > ./hadamard_auto_sync.pto", + "ptoas_command": "ptoas --enable-insert-sync ./hadamard_auto_sync.pto -o ./hadamard_auto_sync.cpp", + "pto_file": "hadamard_auto_sync.pto", + "cpp_file": "hadamard_auto_sync.cpp" + }, + { + "example_dir": "fast_hadamard", + "compile_script": "compile.sh", + "py_source": "hadamard_builder.py", + "py_command": "python 
./hadamard_builder.py --manual-sync > ./hadamard_manual_sync.pto", + "ptoas_command": "ptoas ./hadamard_manual_sync.pto -o ./hadamard_manual_sync.cpp", + "pto_file": "hadamard_manual_sync.pto", + "cpp_file": "hadamard_manual_sync.cpp" + }, + { + "example_dir": "fast_inverse/basic_dense", + "compile_script": "compile.sh", + "py_source": "inverse_builder.py", + "py_command": "python ./inverse_builder.py --matrix-size 64 > ./inverse_basic_dense_64.pto", + "ptoas_command": "ptoas --enable-insert-sync ./inverse_basic_dense_64.pto -o ./inverse_basic_dense_64.cpp", + "pto_file": "inverse_basic_dense_64.pto", + "cpp_file": "inverse_basic_dense_64.cpp" + }, + { + "example_dir": "fast_inverse/block_inversion", + "compile_script": "compile.sh", + "py_source": "inverse_builder.py", + "py_command": "python ./inverse_builder.py --matrix-size 64 > ./inverse_block_inversion_64.pto", + "ptoas_command": "ptoas --enable-insert-sync ./inverse_block_inversion_64.pto -o ./inverse_block_inversion_64.cpp", + "pto_file": "inverse_block_inversion_64.pto", + "cpp_file": "inverse_block_inversion_64.cpp" + }, + { + "example_dir": "fast_inverse/basic_dense", + "compile_script": "compile.sh", + "py_source": "inverse_builder.py", + "py_command": "python ./inverse_builder.py --matrix-size 128 > ./inverse_basic_dense_128.pto", + "ptoas_command": "ptoas --enable-insert-sync ./inverse_basic_dense_128.pto -o ./inverse_basic_dense_128.cpp", + "pto_file": "inverse_basic_dense_128.pto", + "cpp_file": "inverse_basic_dense_128.cpp" + }, + { + "example_dir": "matmul_optimization_guide", + "compile_script": "compile.sh", + "py_source": "step1_baseline.py", + "py_command": "python ./step1_baseline.py > ./step1_baseline.pto", + "ptoas_command": "ptoas --enable-insert-sync ./step1_baseline.pto -o ./step1_baseline.cpp", + "pto_file": "step1_baseline.pto", + "cpp_file": "step1_baseline.cpp", + "dependency": ["common_utils.py"] + }, + { + "example_dir": "matmul_optimization_guide", + "compile_script": 
"compile.sh", + "py_source": "step2_doublebuffer.py", + "py_command": "python ./step2_doublebuffer.py > ./step2_doublebuffer.pto", + "ptoas_command": "ptoas --enable-insert-sync ./step2_doublebuffer.pto -o ./step2_doublebuffer.cpp", + "pto_file": "step2_doublebuffer.pto", + "cpp_file": "step2_doublebuffer.cpp", + "dependency": ["common_utils.py"] + }, + { + "example_dir": "matmul_optimization_guide", + "compile_script": "compile.sh", + "py_source": "step3_swizzle.py", + "py_command": "python ./step3_swizzle.py > ./step3_swizzle.pto", + "ptoas_command": "ptoas --enable-insert-sync ./step3_swizzle.pto -o ./step3_swizzle.cpp", + "pto_file": "step3_swizzle.pto", + "cpp_file": "step3_swizzle.cpp", + "dependency": ["common_utils.py"] + }, + { + "example_dir": "matmul_optimization_guide", + "compile_script": "compile.sh", + "py_source": "step4_manual_pipelining.py", + "py_command": "python ./step4_manual_pipelining.py > ./step4_manual_pipelining.pto", + "ptoas_command": "ptoas ./step4_manual_pipelining.pto -o ./step4_manual_pipelining.cpp", + "pto_file": "step4_manual_pipelining.pto", + "cpp_file": "step4_manual_pipelining.cpp", + "dependency": ["common_utils.py"] + }, + { + "example_dir": "matmul_optimization_guide/experimental", + "compile_script": "compile.sh", + "py_source": "matmul_builder.py", + "py_command": "python ./matmul_builder.py > matmul.pto", + "ptoas_command": "ptoas matmul.pto -o matmul.cpp", + "pto_file": "matmul.pto", + "cpp_file": "matmul.cpp" + }, + { + "example_dir": "simple_static/add_static_multicore", + "compile_script": "compile.sh", + "py_source": "add_builder.py", + "py_command": "python ./add_builder.py > ./add.pto", + "ptoas_command": "ptoas --enable-insert-sync ./add.pto -o ./add.cpp", + "pto_file": "add.pto", + "cpp_file": "add.cpp" + }, + { + "example_dir": "simple_static/matmul_static_singlecore", + "compile_script": "compile.sh", + "py_source": "matmul_builder.py", + "py_command": "python ./matmul_builder.py > matmul.pto", + 
"ptoas_command": "ptoas --enable-insert-sync matmul.pto -o matmul.cpp", + "pto_file": "matmul.pto", + "cpp_file": "matmul.cpp" + } +] diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0b732ac8..dd33e542 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,23 +8,47 @@ on: workflow_dispatch: jobs: + pre-commit: + name: pre-commit + runs-on: ubuntu-24.04 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Run pre-commit checks + run: | + python -m pip install --upgrade pip + python -m pip install pre-commit + pre-commit run --all-files + test: + name: test (${{ matrix.arch }}, ${{ matrix.install-mode }}) strategy: fail-fast: false matrix: include: - arch: x86_64 runs-on: ubuntu-24.04 + install-mode: standard + - arch: x86_64 + runs-on: ubuntu-24.04 + install-mode: editable - arch: aarch64 runs-on: ubuntu-24.04-arm + install-mode: standard + - arch: aarch64 + runs-on: ubuntu-24.04-arm + install-mode: editable runs-on: ${{ matrix.runs-on }} container: image: quay.io/ascend/cann:8.5.0-910b-ubuntu22.04-py3.11 env: - RELEASE_REPO: huawei-csl/PTOAS - RELEASE_TAG: 20260303 + RELEASE_REPO: zhangstevenunity/PTOAS + RELEASE_VER: 0.9 + RELEASE_TAG: v0.9 CLI_DIR: /installers/ptoas-cli PTOISA_COMMIT: 672ee54cb8905bb9f9abbe80ec26ed2054b7a0cc @@ -43,7 +67,7 @@ jobs: - name: Install ptoas wheel run: | - WHEEL_NAME=ptoas-0.1.1-cp311-none-manylinux_2_34_${{ matrix.arch }}.whl + WHEEL_NAME=ptoas-${RELEASE_VER}-cp311-none-manylinux_2_34_${{ matrix.arch }}.whl wget https://github.com/${RELEASE_REPO}/releases/download/${RELEASE_TAG}/${WHEEL_NAME} pip install ./${WHEEL_NAME} python -c "import mlir.ir; from mlir.dialects import pto" @@ -63,8 +87,13 @@ jobs: git clone https://github.com/PTO-ISA/pto-isa.git /sources/pto-isa cd /sources/pto-isa && git checkout ${PTOISA_COMMIT} - - name: Install ptodsl - run: pip install -e ./ptodsl + - name: Install ptodsl (${{ matrix.install-mode }}) + 
run: | + if [ "${{ matrix.install-mode }}" = "standard" ]; then + pip install . + else + pip install -e . + fi - name: Run frontend tests run: pytest -v ./tests/frontend @@ -75,3 +104,9 @@ jobs: pytest -v -m "not require_npu" ./tests/npu env: TORCH_DEVICE_BACKEND_AUTOLOAD: "0" + + - name: Run example translation collection check + run: | + python ./.agent/skills/translate_cpp2py/scripts/collect_example_translate.py \ + --repo-root . \ + --out-dir /tmp/example_translation diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..91bda22c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + - id: check-yaml + - id: check-json + - id: check-merge-conflict + - id: check-added-large-files + - id: check-toml + - id: detect-private-key + - id: check-ast +- repo: https://github.com/psf/black + rev: 25.12.0 + hooks: + - id: black diff --git a/README.md b/README.md index 0229c240..bfffce44 100644 --- a/README.md +++ b/README.md @@ -11,17 +11,26 @@ PTO-DSL provides a programming abstraction similar to [cuTile](https://docs.nvid - Easily interface with [torch-npu](https://gitcode.com/ascend/pytorch) - Lightweight, open-source compiler stack using [PTO Assembler](https://github.com/zhangstevenunity/PTOAS) -**Compare to other kernel programming frameworks** (e.g. 
[tilelang-ascend](https://github.com/tile-ai/tilelang-ascend), [triton-ascend](https://gitcode.com/Ascend/triton-ascend), and [catlass](https://gitcode.com/cann/catlass)): -- PTO-DSL aims for **low-level, explicit, NPU-native primitives** that can match the performance of **programming in [hardware intrinsics](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/API/cceintrinsicapi/cceapi_0001.html)**, filling the gap of a [CuteDSL](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/overview.html)-like low-level Python programming for NPU. +## Installation -## Environment +See [docker/README.md](./docker/README.md) for full reproducible dependencies on NPU. -See [docker](./docker) +Then, install this lightweight DSL package itself: -## Installation +```bash +# install latest commit +pip install git+https://github.com/huawei-csl/pto-dsl.git + +# or stable tag +pip install git+https://github.com/huawei-csl/pto-dsl.git@0.1.0 +``` + +For in-place development: ```bash -pip install -e ./ptodsl +git clone https://github.com/huawei-csl/pto-dsl.git +cd pto-dsl +pip install -e . ``` ## Usage @@ -31,3 +40,11 @@ See [examples](./examples) and [tests](./tests) ## Contribute See [contribute_guide.md](./contribute_guide.md) + +## Compare to other frameworks + +PTO-DSL aims for **low-level, explicit, NPU-native primitives** that can match the performance of **programming in [hardware intrinsics](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/API/cceintrinsicapi/cceapi_0001.html)**. Compared to other (also very good) kernel programming frameworks, it has a bit different scope by design: +- vs [tilelang-ascend](https://github.com/tile-ai/tilelang-ascend): tilelang can also [use PTO-ISA as codegen backend](https://github.com/tile-ai/tilelang-ascend/blob/76553755da078479a7f60cce9c5f0e9a24d0008b/src/target/codegen_ascend_pto.cc). 
PTO-DSL intentionally exposes lower-level control, for example L2 swizzling is one-liner `T.use_swizzle` in tilelang, but is a user-defined custom function in PTO-DSL -- see this [matmul optimization example](examples/aot/matmul_optimization_guide/matmul_optim_guide.md). Once PTO-DSL is more stabilized, it might serve as a component like the [CuteDSL backend for tilelang](https://github.com/tile-ai/tilelang/blob/v0.1.8/src/target/codegen_cutedsl.cc). +- vs [triton-ascend](https://gitcode.com/Ascend/triton-ascend): Both frameworks automate software pipelining based on some MLIR dialects for NPU. PTO-DSL exposes more NPU-native memory hierarchy such as `L0`/`L1`/`UB`. Also, `pto.load`/`pto.store` always maps to native efficient DMA instructions, while `tl.load`/`tl.store` tries to do GPU-style memory coalescing. +- vs [Catlass](https://gitcode.com/cann/catlass): Catlass provides expert-optimized template collections, while PTO-DSL is more like the [CuteDSL](https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/overview.html) layer of Cutlass, offering explicit low-level primitives. +- vs [PyPTO](https://gitcode.com/cann/pypto): PyPTO is a full [MPMD](https://en.wikipedia.org/wiki/Flynn%27s_taxonomy#Multiple_programs,_multiple_data_streams_(MPMD)) dynamic runtime stack, which also [uses PTO-ISA as lowest-level primitive](https://gitcode.com/cann/pypto/tree/r0.1.1/framework/src/interface/tileop). PyPTO's Tensor API abstraction is closer to PyTorch/JAX level, while a PTO-DSL kernel is still [SPMD](https://en.wikipedia.org/wiki/Single_program,_multiple_data) and is closer to CuTile/CuteDSL level. 
diff --git a/docker/Dockerfile b/docker/Dockerfile index e98a0a79..7c7dea8a 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -13,15 +13,26 @@ RUN pip install --no-cache-dir torch==2.9.0 --index-url https://download.pytorch # extra util RUN pip install --no-cache-dir \ pytest pybind11 nanobind setuptools wheel \ - ipython jupyterlab + ipython jupyterlab matplotlib pandas + +# certain operations need latest isa header, not CANN 8.5.0 default +# header on 2026/03/16 +ARG PTOISA_COMMIT=313817be696792a4e16a7ea5994ec98e34391613 +WORKDIR /sources +RUN git clone https://gitcode.com/cann/pto-isa.git \ + && cd pto-isa && git checkout $PTOISA_COMMIT # cache above layers unrelated to ptoas version change +# change this unused arg if you need to force a rebuild of the later lines +ARG CACHE_BURST=1 + # ARG ARCH=x86_64 ARG ARCH=aarch64 -ARG RELEASE_REPO=huawei-csl/PTOAS -ARG RELEASE_TAG=20260228 -ARG WHEEL_NAME=ptoas-0.1.1-cp311-none-manylinux_2_34_${ARCH}.whl +ARG RELEASE_REPO=zhangstevenunity/PTOAS +ARG RELEASE_VER=0.9 +ARG RELEASE_TAG=v${RELEASE_VER} +ARG WHEEL_NAME=ptoas-${RELEASE_VER}-cp311-none-manylinux_2_34_${ARCH}.whl ARG CLI_TAR_NAME=ptoas-bin-${ARCH}.tar.gz WORKDIR /installers/ @@ -50,7 +61,6 @@ RUN ptoas ./tmatmulk.pto -o ./tmatmulk.cpp RUN python ./abs.py > ./abs.pto RUN ptoas --enable-insert-sync ./abs.pto -o ./abs.cpp - # certain operations need latest isa header, not CANN 8.5.0 default # header on 2026/02/14 ARG PTOISA_COMMIT=672ee54cb8905bb9f9abbe80ec26ed2054b7a0cc diff --git a/docker/README.md b/docker/README.md index 55bd125c..93febcd0 100644 --- a/docker/README.md +++ b/docker/README.md @@ -1,19 +1,21 @@ -Usage: +Recommend using [Ascend Docker Runtime](https://gitcode.com/Ascend/mind-cluster/tree/master/component/ascend-docker-runtime) for a reproducible env. Install it on top of normal Docker, using `Ascend-docker-runtime*.run` files in the [Release page](https://gitcode.com/Ascend/mind-cluster/releases).
+ +Then, build and run docker image: ```bash -RELEASE_TAG=20260303 +RELEASE_VER=0.9 sudo docker build \ - --build-arg RELEASE_TAG=$RELEASE_TAG \ - . -t pto_dsl:$RELEASE_TAG + --build-arg RELEASE_VER=$RELEASE_VER \ + . -t pto_dsl:$RELEASE_VER # for specific arch (x86_64 vs aarch64) sudo docker build \ --build-arg ARCH=x86_64 \ - --build-arg RELEASE_TAG=$RELEASE_TAG \ - . -t pto_dsl:$RELEASE_TAG + --build-arg RELEASE_VER=$RELEASE_VER \ + . -t pto_dsl:$RELEASE_VER # to test compile-only -sudo docker run --rm -it pto_dsl:$RELEASE_TAG /bin/bash +sudo docker run --rm -it pto_dsl:$RELEASE_VER /bin/bash # to test on-device execution sudo docker run --rm -it --ipc=host --privileged \ @@ -28,7 +30,7 @@ sudo docker run --rm -it --ipc=host --privileged \ -v /usr/local/Ascend/driver:/usr/local/Ascend/driver:ro \ -v /etc/ascend_install.info:/etc/ascend_install.info:ro \ -v $HOME:/mounted_home -w /mounted_home \ - pto_dsl:$RELEASE_TAG /bin/bash + pto_dsl:$RELEASE_VER /bin/bash ``` ## Appendix: NPU driver diff --git a/examples/aot/activations/geglu_dynamic_multicore/.gitignore b/examples/aot/activations/geglu_dynamic_multicore/.gitignore new file mode 100644 index 00000000..99be97c5 --- /dev/null +++ b/examples/aot/activations/geglu_dynamic_multicore/.gitignore @@ -0,0 +1,3 @@ +geglu.pto +geglu.cpp +geglu_lib.so diff --git a/examples/aot/activations/geglu_dynamic_multicore/README.md b/examples/aot/activations/geglu_dynamic_multicore/README.md new file mode 100644 index 00000000..374bb9cc --- /dev/null +++ b/examples/aot/activations/geglu_dynamic_multicore/README.md @@ -0,0 +1,7 @@ +Usage: + +```bash +bash ./compile.sh +python ./run_geglu.py +python ./bench_geglu.py +``` diff --git a/examples/aot/activations/geglu_dynamic_multicore/bench_geglu.py b/examples/aot/activations/geglu_dynamic_multicore/bench_geglu.py new file mode 100644 index 00000000..2c36e5a7 --- /dev/null +++ b/examples/aot/activations/geglu_dynamic_multicore/bench_geglu.py @@ -0,0 +1,128 @@ +import argparse +import 
ctypes + +import torch +import torch.nn.functional as F +import torch_npu # noqa: F401 + +from ptodsl.test_util import get_test_device + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path, block_dim=24): + lib = ctypes.CDLL(lib_path) + lib.call_kernel.argtypes = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # a + ctypes.c_void_p, # b + ctypes.c_void_p, # c (output) + ctypes.c_uint32, # batch + ctypes.c_uint32, # n_cols + ] + lib.call_kernel.restype = None + + def geglu_func(a, b, c, batch, n_cols, stream_ptr=None): + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ + lib.call_kernel( + block_dim, + stream_ptr, + torch_to_ctypes(a), + torch_to_ctypes(b), + torch_to_ctypes(c), + batch, + n_cols, + ) + + return geglu_func + + +def bench_geglu( + geglu_func, a, b, c, kernel_name="geglu_func", warmup_iters=5, benchmark_iters=50 +): + batch, n_cols = a.shape + # reads a and b, writes c + io_bytes = a.numel() * a.element_size() * 3 + # Overwrite a large buffer between launches to reduce L2 cache reuse. 
+ cache = torch.empty((256 * 1024 * 1024,), dtype=torch.int8, device=a.device) + + def time_op(fn): + for _ in range(warmup_iters): + fn() + torch.npu.synchronize() + + mixed_start = torch.npu.Event(enable_timing=True) + mixed_end = torch.npu.Event(enable_timing=True) + cache_start = torch.npu.Event(enable_timing=True) + cache_end = torch.npu.Event(enable_timing=True) + + mixed_start.record() + for _ in range(benchmark_iters): + cache.zero_() + fn() + mixed_end.record() + torch.npu.synchronize() + + cache_start.record() + for _ in range(benchmark_iters): + cache.zero_() + cache_end.record() + torch.npu.synchronize() + + mixed_total_ms = mixed_start.elapsed_time(mixed_end) + cache_total_ms = cache_start.elapsed_time(cache_end) + kernel_total_ms = max(mixed_total_ms - cache_total_ms, 0.0) + return kernel_total_ms / benchmark_iters + + custom_ms = time_op(lambda: geglu_func(a, b, c, batch, n_cols)) + torch_ms = time_op(lambda: torch.mul(F.gelu(a, approximate="tanh"), b)) + + custom_bw_gbs = (io_bytes / (custom_ms / 1e3)) / 1e9 + torch_bw_gbs = (io_bytes / (torch_ms / 1e3)) / 1e9 + + print( + f"{kernel_name}: {custom_ms:.3f} ms, " + f"effective bandwidth: {custom_bw_gbs:.3f} GB/s " + f"(IO={io_bytes / 1e6:.2f} MB)" + ) + print( + f"torch gelu*b: {torch_ms:.3f} ms, " + f"effective bandwidth: {torch_bw_gbs:.3f} GB/s " + f"(IO={io_bytes / 1e6:.2f} MB)" + ) + + +def run_bench(lib_path, block_dim=24, batch=1024, n_cols=8192): + device = get_test_device() + torch.npu.set_device(device) + + geglu_func = load_lib(lib_path, block_dim=block_dim) + + torch.manual_seed(0) + dtype = torch.float16 + a = torch.randn(batch, n_cols, device=device, dtype=dtype).clamp(-4, 4) + b = torch.randn(batch, n_cols, device=device, dtype=dtype) + c = torch.empty(batch, n_cols, device=device, dtype=dtype) + + geglu_func(a, b, c, batch, n_cols) + torch.npu.synchronize() + + a_f32 = a.float() + ref = (0.5 * a_f32 * (1.0 + torch.tanh(a_f32))).to(dtype) * b + torch.testing.assert_close(c, ref, 
rtol=1e-2, atol=1e-2) + + bench_geglu(geglu_func, a, b, c, kernel_name=f"geglu ({lib_path})") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--lib", default="./geglu_lib.so") + parser.add_argument("--block-dim", type=int, default=24) + parser.add_argument("--batch", type=int, default=1024) + parser.add_argument("--n-cols", type=int, default=8192) + args = parser.parse_args() + run_bench(args.lib, block_dim=args.block_dim, batch=args.batch, n_cols=args.n_cols) diff --git a/examples/aot/activations/geglu_dynamic_multicore/caller.cpp b/examples/aot/activations/geglu_dynamic_multicore/caller.cpp new file mode 100644 index 00000000..85351fd4 --- /dev/null +++ b/examples/aot/activations/geglu_dynamic_multicore/caller.cpp @@ -0,0 +1,26 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "geglu.cpp" +#endif +#include KERNEL_CPP + +#ifndef NUM_CORES +#define NUM_CORES 24 +#endif + +extern "C" void call_kernel( + uint32_t blockDim, + void *stream, + uint8_t *a, + uint8_t *b, + uint8_t *c, + uint32_t batch, + uint32_t n_cols) +{ + uint32_t launch_blocks = blockDim > 0 ? 
blockDim : NUM_CORES; + _kernel<<>>( + reinterpret_cast(a), + reinterpret_cast(b), + reinterpret_cast(c), + static_cast(batch), + static_cast(n_cols)); +} diff --git a/examples/aot/activations/geglu_dynamic_multicore/compile.sh b/examples/aot/activations/geglu_dynamic_multicore/compile.sh new file mode 100755 index 00000000..9da3faa7 --- /dev/null +++ b/examples/aot/activations/geglu_dynamic_multicore/compile.sh @@ -0,0 +1,22 @@ +set -e + +rm -f geglu.pto geglu.cpp geglu_lib.so + +python ./geglu_builder.py > ./geglu.pto +ptoas --enable-insert-sync ./geglu.pto -o ./geglu.cpp + +bisheng \ + -I${ASCEND_TOOLKIT_HOME}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + -DKERNEL_CPP="\"geglu.cpp\"" \ + ./caller.cpp \ + -o ./geglu_lib.so diff --git a/examples/aot/activations/geglu_dynamic_multicore/geglu_builder.py b/examples/aot/activations/geglu_dynamic_multicore/geglu_builder.py new file mode 100644 index 00000000..7d1e88b3 --- /dev/null +++ b/examples/aot/activations/geglu_dynamic_multicore/geglu_builder.py @@ -0,0 +1,179 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +# 32 KB of UB / sizeof(fp16) = 16384 elements per tile +ELEMENTS_PER_TILE = 32 * 1024 // 2 + + +def meta_data(): + dtype = pto.float16 + ptr_type = pto.PtrType(dtype) + index_dtype = pto.int32 + + tensor_type = pto.TensorType(rank=1, dtype=dtype) + subtensor_type = pto.SubTensorType(shape=[1, ELEMENTS_PER_TILE], dtype=dtype) + + tile_cfg = pto.TileBufConfig() + tile_type = pto.TileBufType( + shape=[1, ELEMENTS_PER_TILE], + valid_shape=[1, -1], + dtype=dtype, 
+ memory_space="VEC", + config=tile_cfg, + ) + + return { + "ptr_type": ptr_type, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_type": subtensor_type, + "tile_type": tile_type, + } + + +def build_geglu(fn_name="geglu_fp16"): + """ + Build a dynamic-batch GEGLU kernel in PTO DSL. + + Computes c = gelu_approx(a) * b, where: + gelu_approx(a) = 0.5 * a * (1 + tanh(a)) + tanh(a) = (exp(2a) - 1) / (exp(2a) + 1) + + Constants (1.0, 2.0) are derived from the input tile itself using + the identity exp(a - a) = exp(0) = 1.0, which avoids the need for + scalar-tile broadcast operations not available in PTO DSL. + + UB tile budget (fp16, 5 tiles × 32 KB = 160 KB < 192 KB): + tb_a : input row a + tb_b : input row b + tb_ones : constant 1.0 (recomputed each row via exp(a-a)) + tb_tmp1 : intermediate / final output + tb_tmp2 : intermediate + + Kernel args: + a_ptr : fp16[batch * n_cols] -- gating input + b_ptr : fp16[batch * n_cols] -- linear input + c_ptr : fp16[batch * n_cols] -- output + batch : int32 -- number of rows + n_cols : int32 -- elements per row; must be <= 16384 + """ + + @to_ir_module(meta_data=meta_data) + def _kernel( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + batch_i32: "index_dtype", + n_cols_i32: "index_dtype", + ) -> None: + c0 = const(0) + c1 = const(1) + c_tile = const(ELEMENTS_PER_TILE) + + batch = s.index_cast(batch_i32) + n_cols = s.index_cast(n_cols_i32) + + with pto.vector_section(): + # Guard: n_cols must be in (0, ELEMENTS_PER_TILE]. + + with pto.if_context(n_cols > c0): + with pto.if_context(c_tile >= n_cols): + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + + vid = s.index_cast(cid * sub_bnum + sub_bid) # vector core index + num_cores = s.index_cast( + num_blocks * sub_bnum + ) # number of vector cores + + # Distribute rows across cores (row-level parallelism). 
+ rows_per_core = s.ceil_div(batch, num_cores) + row_start = vid * rows_per_core + row_end = s.min_u(row_start + rows_per_core, batch) + num_rows = row_end - row_start + + total_elems = batch * n_cols + tv_a = pto.as_tensor( + tensor_type, ptr=a_ptr, shape=[total_elems], strides=[c1] + ) + tv_b = pto.as_tensor( + tensor_type, ptr=b_ptr, shape=[total_elems], strides=[c1] + ) + tv_c = pto.as_tensor( + tensor_type, ptr=c_ptr, shape=[total_elems], strides=[c1] + ) + + with pto.if_context(num_rows > c0): + # Allocate 5 UB tiles (160 KB total, well under 192 KB UB). + tb_a = pto.alloc_tile(tile_type, valid_col=n_cols) + tb_b = pto.alloc_tile(tile_type, valid_col=n_cols) + tb_ones = pto.alloc_tile(tile_type, valid_col=n_cols) + tb_tmp1 = pto.alloc_tile(tile_type, valid_col=n_cols) + tb_tmp2 = pto.alloc_tile(tile_type, valid_col=n_cols) + + for row_i in pto.range(c0, num_rows, c1): + gm_offset = (row_start + row_i) * n_cols + + sv_a = pto.slice_view( + subtensor_type, + source=tv_a, + offsets=[gm_offset], + sizes=[n_cols], + ) + sv_b = pto.slice_view( + subtensor_type, + source=tv_b, + offsets=[gm_offset], + sizes=[n_cols], + ) + sv_c = pto.slice_view( + subtensor_type, + source=tv_c, + offsets=[gm_offset], + sizes=[n_cols], + ) + + pto.load(sv_a, tb_a) + pto.load(sv_b, tb_b) + + # Derive constants from data (no scalar-tile broadcast needed): + # a - a = 0 => exp(0) = 1.0 + tile.sub(tb_a, tb_a, tb_tmp2) # tmp2 = 0.0 + tile.exp(tb_tmp2, tb_ones) # ones = 1.0 + + # tanh(a) = (exp(2a) - 1) / (exp(2a) + 1) + tile.add(tb_a, tb_a, tb_tmp1) # tmp1 = 2a + tile.exp(tb_tmp1, tb_tmp1) # tmp1 = exp(2a) + tile.sub(tb_tmp1, tb_ones, tb_tmp2) # tmp2 = exp(2a) - 1 + tile.add(tb_tmp1, tb_ones, tb_tmp1) # tmp1 = exp(2a) + 1 + tile.div(tb_tmp2, tb_tmp1, tb_tmp2) # tmp2 = tanh(a) + + # gelu_approx(a) = a * (1 + tanh(a)) / 2 + tile.add(tb_ones, tb_tmp2, tb_tmp1) # tmp1 = 1 + tanh(a) + tile.mul(tb_a, tb_tmp1, tb_tmp1) # tmp1 = a * (1 + tanh(a)) + tile.add(tb_ones, tb_ones, tb_tmp2) # tmp2 = 2.0 
+ tile.div(tb_tmp1, tb_tmp2, tb_tmp1) # tmp1 = gelu_approx(a) + + # GEGLU: c = gelu_approx(a) * b + tile.mul(tb_tmp1, tb_b, tb_tmp1) # tmp1 = c + pto.store(tb_tmp1, sv_c) + + _ = fn_name + return _kernel + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--fn-name", + default="geglu_fp16", + help="Generated kernel function name.", + ) + args = parser.parse_args() + print(build_geglu(fn_name=args.fn_name)) diff --git a/examples/aot/activations/geglu_dynamic_multicore/run_geglu.py b/examples/aot/activations/geglu_dynamic_multicore/run_geglu.py new file mode 100644 index 00000000..a180206c --- /dev/null +++ b/examples/aot/activations/geglu_dynamic_multicore/run_geglu.py @@ -0,0 +1,121 @@ +import argparse +import ctypes + +import torch +import torch_npu # noqa: F401 + +from ptodsl.test_util import get_test_device + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path, block_dim=24): + lib = ctypes.CDLL(lib_path) + lib.call_kernel.argtypes = [ + ctypes.c_uint32, # blockDim + ctypes.c_void_p, # stream + ctypes.c_void_p, # a + ctypes.c_void_p, # b + ctypes.c_void_p, # c (output) + ctypes.c_uint32, # batch + ctypes.c_uint32, # n_cols + ] + lib.call_kernel.restype = None + + def geglu_func(a, b, c, batch, n_cols, block_dim=block_dim, stream_ptr=None): + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ + lib.call_kernel( + block_dim, + stream_ptr, + torch_to_ctypes(a), + torch_to_ctypes(b), + torch_to_ctypes(c), + batch, + n_cols, + ) + + return geglu_func + + +def geglu_ref(a, b): + """Reference GEGLU matching the PTO kernel. + + Computes c = gelu_approx(a) * b, where: + gelu_approx(a) = 0.5 * a * (1 + tanh(a)) + tanh(a) = (exp(2a) - 1) / (exp(2a) + 1) + + Note: This is a simplified tanh-based GELU (without the polynomial + inner argument used in the full approximation). 
It matches what the + PTO kernel computes using only tile-tile operations. + """ + a_f32 = a.float() + gelu_a = 0.5 * a_f32 * (1.0 + torch.tanh(a_f32)) + return gelu_a.to(a.dtype) * b + + +def test_geglu(lib_path, block_dim=24): + device = get_test_device() + torch.npu.set_device(device) + + geglu = load_lib(lib_path=lib_path, block_dim=block_dim) + + torch.manual_seed(0) + dtype = torch.float16 + batch_list = [1, 4, 22, 65] + n_cols_list = [128, 256, 512, 1024, 2048, 4096, 8192, 16384] + + results = [] + for batch in batch_list: + for n_cols in n_cols_list: + # Use small range to stay within fp16 exp range (avoid overflow). + a = torch.randn(batch, n_cols, device=device, dtype=dtype).clamp(-4, 4) + b = torch.randn(batch, n_cols, device=device, dtype=dtype) + c = torch.empty(batch, n_cols, device=device, dtype=dtype) + + y_ref = geglu_ref(a, b) + geglu(a, b, c, batch, n_cols) + torch.npu.synchronize() + + is_match = True + detail = "" + try: + torch.testing.assert_close(c, y_ref, rtol=1e-2, atol=1e-2) + except AssertionError as err: + is_match = False + detail = str(err).strip() if str(err) else "assert_close failed" + + status = "match" if is_match else "mismatch" + print(f"[{status}] batch={batch}, n_cols={n_cols}, lib={lib_path}") + if detail: + print(" detail:") + print(detail) + results.append((batch, n_cols, status, detail)) + + print(f"\ndetailed summary for {lib_path}:") + for batch, n_cols, status, detail in results: + msg = f" batch={batch}, n_cols={n_cols}, status={status}" + print(msg) + if detail: + print(" detail:") + print(detail) + return results + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--lib", + default="./geglu_lib.so", + help="Path to the shared library generated by compile.sh.", + ) + parser.add_argument( + "--block-dim", + type=int, + default=24, + help="Kernel blockDim (default: 24).", + ) + args = parser.parse_args() + test_geglu(args.lib, block_dim=args.block_dim) diff --git 
a/examples/aot/relu_dynamic_multicore/.gitignore b/examples/aot/activations/relu_dynamic_multicore/.gitignore similarity index 100% rename from examples/aot/relu_dynamic_multicore/.gitignore rename to examples/aot/activations/relu_dynamic_multicore/.gitignore diff --git a/examples/aot/relu_dynamic_multicore/README.md b/examples/aot/activations/relu_dynamic_multicore/README.md similarity index 100% rename from examples/aot/relu_dynamic_multicore/README.md rename to examples/aot/activations/relu_dynamic_multicore/README.md diff --git a/examples/aot/relu_dynamic_multicore/caller.cpp b/examples/aot/activations/relu_dynamic_multicore/caller.cpp similarity index 100% rename from examples/aot/relu_dynamic_multicore/caller.cpp rename to examples/aot/activations/relu_dynamic_multicore/caller.cpp diff --git a/examples/aot/relu_dynamic_multicore/compile.sh b/examples/aot/activations/relu_dynamic_multicore/compile.sh similarity index 100% rename from examples/aot/relu_dynamic_multicore/compile.sh rename to examples/aot/activations/relu_dynamic_multicore/compile.sh diff --git a/examples/aot/relu_dynamic_multicore/relu_builder.py b/examples/aot/activations/relu_dynamic_multicore/relu_builder.py similarity index 71% rename from examples/aot/relu_dynamic_multicore/relu_builder.py rename to examples/aot/activations/relu_dynamic_multicore/relu_builder.py index dc6acbd6..3e84659b 100644 --- a/examples/aot/relu_dynamic_multicore/relu_builder.py +++ b/examples/aot/activations/relu_dynamic_multicore/relu_builder.py @@ -1,5 +1,5 @@ -from ptodsl import to_ir_module -import ptodsl.language as pto +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s def build(): @@ -29,39 +29,45 @@ def meta_data(): "tile_w": tile_w, } - const = pto.const + const = s.const @to_ir_module(meta_data=meta_data) - def sync_kernel_dyn(arg0: "ptr_type", arg1: "ptr_type", argN: "index_dtype") -> None: + def sync_kernel_dyn( + arg0: "ptr_type", arg1: "ptr_type", argN: "index_dtype" + ) -> None: 
with pto.vector_section(): c0 = const(0) c1 = const(1) c_tile_w = const(tile_w) - total_elements = pto.index_cast(argN) + total_elements = s.index_cast(argN) - num_blocks = pto.index_cast(pto.get_block_num()) - num_el_per_core = pto.ceil_div(total_elements, num_blocks) + num_blocks = s.index_cast(pto.get_block_num()) + num_el_per_core = s.ceil_div(total_elements, num_blocks) # Per-core range: [core_start, core_end) - bid = pto.index_cast(pto.get_block_idx()) + bid = s.index_cast(pto.get_block_idx()) core_start = bid * num_el_per_core core_end_unclamped = core_start + num_el_per_core - core_end = pto.min_u(core_end_unclamped, total_elements) + core_end = s.min_u(core_end_unclamped, total_elements) core_len = core_end - core_start # Per-core number of tiles: ceil(core_len / tile_w). - num_tiles = pto.ceil_div(core_len, c_tile_w) + num_tiles = s.ceil_div(core_len, c_tile_w) # GM tensors shape N with stride 1. - tv0 = pto.as_tensor(tensor_type, ptr=arg0, shape=[total_elements], strides=[c1]) - tv1 = pto.as_tensor(tensor_type, ptr=arg1, shape=[total_elements], strides=[c1]) - - for i in pto.for_range(c0, num_tiles, c1): + tv0 = pto.as_tensor( + tensor_type, ptr=arg0, shape=[total_elements], strides=[c1] + ) + tv1 = pto.as_tensor( + tensor_type, ptr=arg1, shape=[total_elements], strides=[c1] + ) + + for i in pto.range(c0, num_tiles, c1): offset_tile = i * c_tile_w offset_total = core_start + offset_tile remaining_core = core_end - offset_total - valid_len = pto.min_u(remaining_core, c_tile_w) + valid_len = s.min_u(remaining_core, c_tile_w) # Keep per-iteration tile alloc to match original behavior. 
tb0 = pto.alloc_tile(tile_type, valid_row=c1, valid_col=valid_len) @@ -82,10 +88,11 @@ def sync_kernel_dyn(arg0: "ptr_type", arg1: "ptr_type", argN: "index_dtype") -> ) pto.load(sv0, tb0) - pto.relu(tb0, tb1) + tile.relu(tb0, tb1) pto.store(tb1, sv1) return sync_kernel_dyn + if __name__ == "__main__": print(build()) diff --git a/examples/aot/relu_dynamic_multicore/run_relu.py b/examples/aot/activations/relu_dynamic_multicore/run_relu.py similarity index 80% rename from examples/aot/relu_dynamic_multicore/run_relu.py rename to examples/aot/activations/relu_dynamic_multicore/run_relu.py index 99465bb2..9281ea50 100644 --- a/examples/aot/relu_dynamic_multicore/run_relu.py +++ b/examples/aot/activations/relu_dynamic_multicore/run_relu.py @@ -35,7 +35,7 @@ def load_lib(lib_path, block_dim, check_type=True): def relu_func(x, y, n, block_dim=block_dim, stream_ptr=None): if stream_ptr is None: - stream_ptr= torch.npu.current_stream()._as_parameter_ + stream_ptr = torch.npu.current_stream()._as_parameter_ lib.call_kernel( block_dim, @@ -54,13 +54,12 @@ def test_relu(verbose=True): torch.npu.set_device(device) dtype = torch.float32 - # allocate a bigger buffer than the actual number of elements to test the padding behavior shape = [1, 2 * 128] for BLOCK_DIM in range(1, 21): relu_kernel = load_lib("relu_lib.so", block_dim=BLOCK_DIM) - print(BLOCK_DIM) - for num_elements in [3,7,13,97,143, 2*128]: + print(BLOCK_DIM) + for num_elements in [3, 7, 13, 97, 143, 2 * 128]: x = torch.rand(shape, device=device, dtype=dtype) - 0.5 y = torch.full(shape, -10, device=device, dtype=dtype) relu_kernel(x, y, n=num_elements) @@ -73,17 +72,22 @@ def test_relu(verbose=True): step = 1 for i in range(0, shape[0]): for j in range(0, shape[1], step): - if correct[i, j:j+step].all(): - print('X', end='') + if correct[i, j : j + step].all(): + print("X", end="") else: - print('.', end='') + print(".", end="") if j == num_elements - 1: - print('|', end='') - print('|') + print("|", end="") + 
print("|") - torch.testing.assert_close(y.flatten()[:num_elements], y_ref.flatten()[:num_elements]) + torch.testing.assert_close( + y.flatten()[:num_elements], y_ref.flatten()[:num_elements] + ) # Make sure we didn't write past the end of the buffer - torch.testing.assert_close(y.flatten()[num_elements:], torch.full_like(y.flatten()[num_elements:], -10)) + torch.testing.assert_close( + y.flatten()[num_elements:], + torch.full_like(y.flatten()[num_elements:], -10), + ) print(f"RELU test pass for shape {shape}! using {BLOCK_DIM} cores") diff --git a/examples/aot/matmul_dynbatch_multicore/.gitignore b/examples/aot/batch_matmul/matmul_dynbatch_multicore/.gitignore similarity index 100% rename from examples/aot/matmul_dynbatch_multicore/.gitignore rename to examples/aot/batch_matmul/matmul_dynbatch_multicore/.gitignore diff --git a/examples/aot/matmul_dynbatch_multicore/README.md b/examples/aot/batch_matmul/matmul_dynbatch_multicore/README.md similarity index 100% rename from examples/aot/matmul_dynbatch_multicore/README.md rename to examples/aot/batch_matmul/matmul_dynbatch_multicore/README.md diff --git a/examples/aot/matmul_dynbatch_multicore/caller.cpp b/examples/aot/batch_matmul/matmul_dynbatch_multicore/caller.cpp similarity index 100% rename from examples/aot/matmul_dynbatch_multicore/caller.cpp rename to examples/aot/batch_matmul/matmul_dynbatch_multicore/caller.cpp diff --git a/examples/aot/matmul_dynbatch_multicore/compile.sh b/examples/aot/batch_matmul/matmul_dynbatch_multicore/compile.sh similarity index 100% rename from examples/aot/matmul_dynbatch_multicore/compile.sh rename to examples/aot/batch_matmul/matmul_dynbatch_multicore/compile.sh diff --git a/examples/aot/matmul_dynbatch_multicore/matmul_builder.py b/examples/aot/batch_matmul/matmul_dynbatch_multicore/matmul_builder.py similarity index 54% rename from examples/aot/matmul_dynbatch_multicore/matmul_builder.py rename to examples/aot/batch_matmul/matmul_dynbatch_multicore/matmul_builder.py index 
c3b7f96d..28015228 100644 --- a/examples/aot/matmul_dynbatch_multicore/matmul_builder.py +++ b/examples/aot/batch_matmul/matmul_dynbatch_multicore/matmul_builder.py @@ -1,7 +1,7 @@ from mlir.ir import IntegerType -from ptodsl import to_ir_module -import ptodsl.language as pto +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s def build( @@ -29,14 +29,26 @@ def meta_data(): tile_view_out = pto.SubTensorType(shape=[M, N], dtype=dtype) tile_view_bias = pto.SubTensorType(shape=[1, N], dtype=dtype) - tile_buf_aMat = pto.TileBufType(shape=[M, BASEK], dtype=dtype, memory_space="MAT") - tile_buf_bMat = pto.TileBufType(shape=[BASEK, N], dtype=dtype, memory_space="MAT") - tile_buf_biasData = pto.TileBufType(shape=[1, N], dtype=dtype, memory_space="MAT") - - tile_buf_aTile = pto.TileBufType(shape=[M, BASEK], dtype=dtype, memory_space="LEFT") - tile_buf_bTile = pto.TileBufType(shape=[BASEK, N], dtype=dtype, memory_space="RIGHT") + tile_buf_aMat = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="MAT" + ) + tile_buf_bMat = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="MAT" + ) + tile_buf_biasData = pto.TileBufType( + shape=[1, N], dtype=dtype, memory_space="MAT" + ) + + tile_buf_aTile = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="LEFT" + ) + tile_buf_bTile = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="RIGHT" + ) tile_buf_cTile = pto.TileBufType(shape=[M, N], dtype=dtype, memory_space="ACC") - tile_buf_biasTile = pto.TileBufType(shape=[1, N], dtype=dtype, memory_space="BIAS") + tile_buf_biasTile = pto.TileBufType( + shape=[1, N], dtype=dtype, memory_space="BIAS" + ) return { "ptr_type": ptr_dtype, @@ -56,7 +68,7 @@ def meta_data(): "tile_buf_biasTile": tile_buf_biasTile, } - const = pto.const + const = s.const @to_ir_module(meta_data=meta_data) def RunTMATMULSplitK( @@ -78,20 +90,28 @@ def RunTMATMULSplitK( cTileM = const(M) cTileN = const(N) - batch = pto.index_cast(batch_i32) + 
batch = s.index_cast(batch_i32) cBM = batch * cM - num_blocks = pto.index_cast(pto.get_block_num()) - batches_per_core = pto.ceil_div(batch, num_blocks) - bid = pto.index_cast(pto.get_block_idx()) + num_blocks = s.index_cast(pto.get_block_num()) + batches_per_core = s.ceil_div(batch, num_blocks) + bid = s.index_cast(pto.get_block_idx()) b_start = bid * batches_per_core b_end_unclamped = b_start + batches_per_core - b_end = pto.min_u(b_end_unclamped, batch) - - tvA = pto.as_tensor(tensor_type, ptr=a_ptr, shape=[cBM, cK], strides=[cK, c1]) - tvB = pto.as_tensor(tensor_type, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1]) - tvOut = pto.as_tensor(tensor_type, ptr=out_ptr, shape=[cBM, cN], strides=[cN, c1]) - tvBias = pto.as_tensor(tensor_type, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1]) + b_end = s.min_u(b_end_unclamped, batch) + + tvA = pto.as_tensor( + tensor_type, ptr=a_ptr, shape=[cBM, cK], strides=[cK, c1] + ) + tvB = pto.as_tensor( + tensor_type, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1] + ) + tvOut = pto.as_tensor( + tensor_type, ptr=out_ptr, shape=[cBM, cN], strides=[cN, c1] + ) + tvBias = pto.as_tensor( + tensor_type, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1] + ) aMatTile = pto.alloc_tile(tile_buf_aMat) bMatTile = pto.alloc_tile(tile_buf_bMat) @@ -101,14 +121,29 @@ def RunTMATMULSplitK( cTile = pto.alloc_tile(tile_buf_cTile) biasTile = pto.alloc_tile(tile_buf_biasTile) - for b_idx in pto.for_range(b_start, b_end, c1): + for b_idx in pto.range(b_start, b_end, c1): row_off = b_idx * cM - for i in pto.for_range(c0, cIter, c1): + for i in pto.range(c0, cIter, c1): kOff = i * cBASEK - svA = pto.slice_view(tile_view_a, source=tvA, offsets=[row_off, kOff], sizes=[cTileM, cBASEK]) - svB = pto.slice_view(tile_view_b, source=tvB, offsets=[kOff, c0], sizes=[cBASEK, cTileN]) - svBias = pto.slice_view(tile_view_bias, source=tvBias, offsets=[c0, c0], sizes=[c1, cTileN]) + svA = pto.slice_view( + tile_view_a, + source=tvA, + offsets=[row_off, kOff], + sizes=[cTileM, 
cBASEK], + ) + svB = pto.slice_view( + tile_view_b, + source=tvB, + offsets=[kOff, c0], + sizes=[cBASEK, cTileN], + ) + svBias = pto.slice_view( + tile_view_bias, + source=tvBias, + offsets=[c0, c0], + sizes=[c1, cTileN], + ) pto.load(svA, aMatTile) pto.load(svB, bMatTile) @@ -117,32 +152,37 @@ def RunTMATMULSplitK( pto.record_wait_pair("LOAD", "MOV_M2L", event_id=0) - pto.mov(aMatTile, aTile) - pto.mov(bMatTile, bTile) + tile.mov(aMatTile, aTile) + tile.mov(bMatTile, bTile) with pto.if_context(isBias): - pto.mov(biasDataTile, biasTile) + tile.mov(biasDataTile, biasTile) pto.record_wait_pair("MOV_M2L", "MATMUL", event_id=0) - is_i0 = pto.eq(i, c0) + is_i0 = s.eq(i, c0) def _first_iter(): pto.cond( isBias, - lambda: pto.matmul_bias(aTile, bTile, biasTile, cTile), - lambda: pto.matmul(aTile, bTile, cTile), + lambda: tile.matmul_bias(aTile, bTile, biasTile, cTile), + lambda: tile.matmul(aTile, bTile, cTile), ) pto.cond( is_i0, _first_iter, - lambda: pto.matmul_acc(cTile, aTile, bTile, cTile), + lambda: tile.matmul_acc(cTile, aTile, bTile, cTile), ) pto.record_wait_pair("MATMUL", "LOAD", event_id=0) pto.record_wait_pair("MATMUL", "STORE_ACC", event_id=0) - svOut = pto.slice_view(tile_view_out, source=tvOut, offsets=[row_off, c0], sizes=[cTileM, cTileN]) + svOut = pto.slice_view( + tile_view_out, + source=tvOut, + offsets=[row_off, c0], + sizes=[cTileM, cTileN], + ) pto.store(cTile, svOut) pto.record_wait_pair("STORE_ACC", "MATMUL", event_id=0) @@ -150,4 +190,4 @@ def _first_iter(): if __name__ == "__main__": - print(build()) \ No newline at end of file + print(build()) diff --git a/examples/aot/matmul_dynbatch_multicore/run_matmul.py b/examples/aot/batch_matmul/matmul_dynbatch_multicore/run_matmul.py similarity index 93% rename from examples/aot/matmul_dynbatch_multicore/run_matmul.py rename to examples/aot/batch_matmul/matmul_dynbatch_multicore/run_matmul.py index b4c8d79e..197087a1 100644 --- a/examples/aot/matmul_dynbatch_multicore/run_matmul.py +++ 
b/examples/aot/batch_matmul/matmul_dynbatch_multicore/run_matmul.py @@ -11,11 +11,7 @@ def torch_to_ctypes(tensor): def load_lib(lib_path): lib = ctypes.CDLL(lib_path) - def matmul_func( - c, a, b, batch_size, - block_dim, - stream_ptr=None - ): + def matmul_func(c, a, b, batch_size, block_dim, stream_ptr=None): if stream_ptr is None: stream_ptr = torch.npu.current_stream()._as_parameter_ lib.call_kernel( diff --git a/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/.gitignore b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/.gitignore new file mode 100644 index 00000000..04529667 --- /dev/null +++ b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/.gitignore @@ -0,0 +1 @@ +mul.cpp diff --git a/examples/aot/matmul_dynbatch_multicore_opt/README.md b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/README.md similarity index 100% rename from examples/aot/matmul_dynbatch_multicore_opt/README.md rename to examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/README.md diff --git a/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/caller.cpp b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/caller.cpp new file mode 100644 index 00000000..fff32469 --- /dev/null +++ b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/caller.cpp @@ -0,0 +1,13 @@ +#include "mul.cpp" + +extern "C" void call_kernel( + uint32_t blockDim, void* stream, + uint8_t* c, uint8_t* a, uint8_t* b, uint32_t batch_size) +{ + RunTMATMULSplitK<<>>( + reinterpret_cast(c), + reinterpret_cast(a), + reinterpret_cast(b), + nullptr, false, batch_size + ); +} diff --git a/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/compile.sh b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/compile.sh new file mode 100755 index 00000000..0cd49d16 --- /dev/null +++ b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/compile.sh @@ -0,0 +1,12 @@ +rm mul.cpp matmul_kernel.so + +python ./matmul_dsl.py | ptoas > mul.cpp + +bisheng -fPIC -shared 
-xcce -O2 -std=c++17 \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -I${ASCEND_TOOLKIT_HOME}/include \ + --cce-soc-version=Ascend910B2 \ + --cce-soc-core-type=CubeCore \ + -I/mounted_home/pto-isa/include \ + ./caller.cpp \ + -o ./matmul_kernel.so diff --git a/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/matmul_dsl.py b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/matmul_dsl.py new file mode 100644 index 00000000..8ec0dfc9 --- /dev/null +++ b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/matmul_dsl.py @@ -0,0 +1,196 @@ +from mlir.ir import IntegerType + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + + +def build(M=128, K=128, N=128): + def meta_data(): + dtype = pto.float16 + dtype_acc_tile = pto.float32 + ptr_type = pto.PtrType(dtype) + i32 = pto.int32 + i1 = IntegerType.get_signless(1) + + tensor_type = pto.TensorType(rank=2, dtype=dtype) + tensor_type3d = pto.TensorType(rank=3, dtype=dtype) + + tile_view_a = pto.SubTensorType(shape=[M, K], dtype=dtype) + tile_view_b = pto.SubTensorType(shape=[K, N], dtype=dtype) + tile_view_c = pto.SubTensorType(shape=[M, N], dtype=dtype) + tile_buf_aMat = pto.TileBufType(shape=[M, K], dtype=dtype, memory_space="MAT") + tile_buf_bMat = pto.TileBufType(shape=[K, N], dtype=dtype, memory_space="MAT") + tile_buf_aTile = pto.TileBufType(shape=[M, K], dtype=dtype, memory_space="LEFT") + tile_buf_bTile = pto.TileBufType( + shape=[K, N], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_cTile = pto.TileBufType( + shape=[M, N], dtype=dtype_acc_tile, memory_space="ACC" + ) + # TODO: Get rid of this? 
+ return locals() + + const = s.const + + # Until we have set_dyn_flag with event_id as SSA values + # event_id can be dynamic SSA value + # https://github.com/zhangstevenunity/PTOAS/pull/176 + def record_event(src, dst, event_id): + pto.cond( + event_id == const(0), + lambda: pto.record_event(src, dst, event_id=0), + lambda: pto.record_event(src, dst, event_id=1), + ) + + def wait_event(src, dst, event_id): + pto.cond( + event_id == const(0), + lambda: pto.wait_event(src, dst, event_id=0), + lambda: pto.wait_event(src, dst, event_id=1), + ) + + @to_ir_module(meta_data=meta_data) + def RunTMATMULSplitK( + out_ptr: "ptr_type", + a_ptr: "ptr_type", + b_ptr: "ptr_type", + bias_ptr: "ptr_type", + isBias: "i1", + batch_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + c2 = const(2) + cM = const(M) + cK = const(K) + cN = const(N) + batch = s.index_cast(batch_i32) + + num_blocks = s.index_cast(pto.get_block_num()) + # TODO round robin + batches_per_core = s.ceil_div(batch, num_blocks) + bid = s.index_cast(pto.get_block_idx()) + b_start = bid * batches_per_core + b_end_unclamped = b_start + batches_per_core + b_end = s.min_u(b_end_unclamped, batch) + + # TODO: if no batches are assigned to this core, early return + + tvA = pto.as_tensor( + tensor_type3d, + ptr=a_ptr, + shape=[batch, cM, cK], + strides=[cK * cM, cK, c1], + ) + tvC = pto.as_tensor( + tensor_type3d, + ptr=out_ptr, + shape=[batch, cM, cN], + strides=[cM * cN, cN, c1], + ) + tvB = pto.as_tensor( + tensor_type, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1] + ) + + # TODO: pre-fetch more than two tiles into L1 + NUM_BUFFERS = 2 + aMatTiles = [pto.alloc_tile(tile_buf_aMat) for _ in range(NUM_BUFFERS)] + bMatTile = pto.alloc_tile(tile_buf_bMat) + # Ping and pong buffers in L0A/C + aTiles = [pto.alloc_tile(tile_buf_aTile) for _ in range(NUM_BUFFERS)] + cTiles = [pto.alloc_tile(tile_buf_cTile) for _ in range(NUM_BUFFERS)] + bTile = pto.alloc_tile(tile_buf_bTile) + + # Put B in L0B + svB 
= pto.slice_view( + tile_view_b, source=tvB, offsets=[c0, c0], sizes=[cK, cN] + ) + pto.load(svB, bMatTile) + pto.record_wait_pair("LOAD", "MOV_M2L", event_id=0) + tile.mov(bMatTile, bTile) + # TODO: wait here so we can use full l1 memory later for A. + + # load in the first tile from GM->L1 + svA = pto.slice_view( + tile_view_a, source=tvA, offsets=[b_start, c0, c0], sizes=[c1, cM, cK] + ) + curr = c1 - (b_start % c2) + pto.cond( + curr == c1, + lambda: pto.load(svA, aMatTiles[0]), + lambda: pto.load(svA, aMatTiles[1]), + ) + record_event("LOAD", "MOV_M2L", event_id=curr) + + # TODO: fix wait events if batch size is 1/2 + # signal to LOAD that L1 can be overwritten + pto.record_event("MOV_M2L", "LOAD", event_id=[0, 1]) + # signal to MOV that L0 can be overwritten + pto.record_event("MATMUL", "MOV_M2L", event_id=[0, 1]) + # signal to MATMUL that it can overwrite L0C + pto.record_event("STORE_ACC", "MATMUL", event_id=[0, 1]) + + for b_idx in pto.range(b_start, b_end, c1): + curr = b_idx % c2 + svA = pto.slice_view( + tile_view_a, + source=tvA, + offsets=[b_idx + c1, c0, c0], + sizes=[c1, cM, cK], + ) + svC = pto.slice_view( + tile_view_c, source=tvC, offsets=[b_idx, c0, c0], sizes=[c1, cM, cN] + ) + + ########## Load tile A for iteration i+1 from GM -> L1 + wait_event("MOV_M2L", "LOAD", event_id=curr) + with pto.if_context(b_idx + c1 < b_end): + pto.cond( + curr == c1, + lambda: pto.load(svA, aMatTiles[0]), + lambda: pto.load(svA, aMatTiles[1]), + ) + record_event("LOAD", "MOV_M2L", event_id=curr) + + ########## Move A1 and A2 into L0A + wait_event("LOAD", "MOV_M2L", event_id=c1 - curr) + wait_event("MATMUL", "MOV_M2L", event_id=curr) + pto.cond( + curr == c0, + lambda: tile.mov(aMatTiles[0], aTiles[0]), + lambda: tile.mov(aMatTiles[1], aTiles[1]), + ) + with pto.if_context(b_idx + c2 < b_end): + record_event("MOV_M2L", "LOAD", event_id=curr) + record_event("MOV_M2L", "MATMUL", event_id=curr) + + ########## Perform matmul + wait_event("MOV_M2L", "MATMUL", 
event_id=curr) + wait_event("STORE_ACC", "MATMUL", event_id=curr) + pto.cond( + curr == c0, + lambda: tile.matmul(aTiles[0], bTile, cTiles[0]), + lambda: tile.matmul(aTiles[1], bTile, cTiles[1]), + ) + record_event("MATMUL", "STORE_ACC", event_id=curr) + with pto.if_context(b_idx + c2 < b_end): + record_event("MATMUL", "MOV_M2L", event_id=curr) + + ######### Store + wait_event("MATMUL", "STORE_ACC", event_id=curr) + pto.cond( + curr == c0, + lambda: pto.store(cTiles[0], svC), + lambda: pto.store(cTiles[1], svC), + ) + with pto.if_context(b_idx + c2 < b_end): + record_event("STORE_ACC", "MATMUL", event_id=curr) + + pto.barrier("LOAD") + + return RunTMATMULSplitK + + +if __name__ == "__main__": + print(build()) diff --git a/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/matmul_ref.cpp b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/matmul_ref.cpp new file mode 100644 index 00000000..e0848f5d --- /dev/null +++ b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/matmul_ref.cpp @@ -0,0 +1,133 @@ +#include "pto/pto-inst.hpp" +using namespace pto; +__global__ AICORE void RunTMATMULSplitK(__gm__ half *v1, __gm__ half *v2, __gm__ half *v3, __gm__ half *v4, bool v5, + int32_t v6) { + unsigned v7 = 16384; + unsigned v8 = 128; + unsigned v9 = 1; + unsigned v10 = 0; + int32_t v11 = 0; + int32_t v12 = 1; + int32_t v13 = 2; + int32_t v14 = 128; + int32_t v15 = 16384; + int64_t v16 = 32768; + int64_t v17 = 65536; + int64_t v18 = 0; + using T = float; + +#if defined(__DAV_CUBE__) + int64_t v19 = get_block_num(); + int32_t v20 = (int32_t)((int64_t)v19); + int32_t v21 = v6 / v20; + int32_t v22 = v6 % v20 != v11 && v6 < v11 == v20 < v11 ? 
v21 + v12 : v21; + int64_t v23 = get_block_idx(); + int32_t v24 = (int32_t)((uint32_t)((int32_t)(int64_t)v23) * (uint32_t)v22); + int32_t v25 = (int32_t)((uint32_t)v24 + (uint32_t)v22); + Tile A1_l1; + TASSIGN(A1_l1, v16); + Tile A2_l1; + TASSIGN(A2_l1, v17); + + Tile A1_l0; + TASSIGN(A1_l0, v18); + Tile A2_l0; + TASSIGN(A2_l0, v16); + + Tile C1_l0; + TASSIGN(C1_l0, v18); + Tile C2_l0; + TASSIGN(C2_l0, v17); + + Tile v28; + TASSIGN(v28, v18); + Tile B_l0; + TASSIGN(B_l0, v18); + pto::Shape<1, 1, 1, 128, 128> v34 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v35 = pto::Stride<16384, 16384, 16384, 128, 1>(); + + using GMType = + GlobalTensor, pto::Stride<16384, 16384, 16384, 128, 1>, pto::Layout::ND>; + GMType v36 = GMType(v3 + (v10 + v10 * (unsigned)v14 + v10 * (unsigned)v12), v34, v35); + TLOAD(v28, v36); + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + TMOV(B_l0, v28); + + pto::Shape<1, 1, 1, 128, 128> v39 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v40 = pto::Stride<16384, 16384, 16384, 128, 1>(); + pto::Shape<1, 1, 1, 128, 128> v49 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v50 = pto::Stride<16384, 16384, 16384, 128, 1>(); + pto::Shape<1, 1, 1, 128, 128> v43 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v44 = pto::Stride<16384, 16384, 16384, 128, 1>(); + pto::Shape<1, 1, 1, 128, 128> v46 = pto::Shape<1, 1, 1, 128, 128>(); + pto::Stride<16384, 16384, 16384, 128, 1> v47 = pto::Stride<16384, 16384, 16384, 128, 1>(); + + int end = ((size_t)((uint32_t)v25 < (uint32_t)v6 ? 
v25 : v6)); + // i-2 + int curr = v24 & 1; + set_flag(PIPE_MTE1, PIPE_MTE2, curr); // set(1) + set_flag(PIPE_M, PIPE_MTE1, curr); // set(3) + set_flag(PIPE_FIX, PIPE_M, curr); // set(4) + + // i-1 + // must load the first tile from GM->l1 here since the loop always loads for + // next iteration + GMType A_gm_first = GMType(v2 + v24 * v15, v39, v40); + // this is iteration i-1, in this case -1 + curr = 1 - curr; // since v24 can start at odd/even i must load the right tile + TLOAD(curr == 1 ? A1_l1 : A2_l1, A_gm_first); + set_flag(PIPE_MTE2, PIPE_MTE1, curr); // set(2) tell MTE1 that MTE2 finished. + + set_flag(PIPE_MTE1, PIPE_MTE2, curr); // set(1) + set_flag(PIPE_M, PIPE_MTE1, curr); // set(3) + set_flag(PIPE_FIX, PIPE_M, curr); // set(4) + + for (size_t i = v24; i < end; i += 1) { + curr = i & 1; + // Global memory for A tiles + GMType v45 = GMType(v2 + (i + 1) * v15, v43, v44); + // GM tile C_1 and C_2 + GMType v48 = GMType(v1 + i * v15, v46, v47); + + // Start loading the tile used in matmul at iteration i+1 + wait_flag(PIPE_MTE1, PIPE_MTE2, curr); // (1, i-2) wait until the MOV at i-2 has completed + if (i + 1 < end) { + TLOAD(curr == 1 ? A1_l1 : A2_l1, v45); + set_flag(PIPE_MTE2, PIPE_MTE1, curr); // set(2, i+1) notify the mov below in iteration i+1 that the load completed + } + + // mov + wait_flag(PIPE_MTE2, PIPE_MTE1, 1 - curr); // (2, i-1) last iteration loaded the tile into l1, so + // for us to move to l0 we wait for last it + wait_flag(PIPE_M, PIPE_MTE1, curr); // (3, i-2) make sure the matmul from + // i-2 finished so we can overwrite l0A + TMOV(curr == 0 ? A1_l0 : A2_l0, curr == 0 ? A1_l1 : A2_l1); + if (i + 2 < end) { + set_flag(PIPE_MTE1, PIPE_MTE2, curr); // set(1, i+2) notify load at iteration i+2 that it's ready + } + set_flag(PIPE_MTE1, PIPE_M, curr); // set(5, i) simply notify matmul at it. i it is ready. + + // matmul + wait_flag(PIPE_FIX, PIPE_M, curr); // (4, i-2) wait until the STORE at it. 
+ // i-2 has written back from L0C + wait_flag(PIPE_MTE1, PIPE_M, curr); // (5, i) need the tile that is moved into L0A at iteration i + TMATMUL(curr == 0 ? C1_l0 : C2_l0, curr == 0 ? A1_l0 : A2_l0, B_l0); + set_flag(PIPE_M, PIPE_FIX, curr); // set(6, i) notify store in this + // iteration i, that matmul is done + if (i + 2 < end) { + set_flag(PIPE_M, PIPE_MTE1, curr); // set(3, i+2) notify mov in iteration i+2, that matmul is done + } + + // store + wait_flag(PIPE_M, PIPE_FIX, curr); // (6, i) wait for matmul in it. i to be done + TSTORE(v48, curr == 0 ? C1_l0 : C2_l0); + if (i + 2 < end) { + set_flag(PIPE_FIX, PIPE_M, curr); // set(4, i+2) notify matmul in i+2 that store is complete + } + } + +#endif // __DAV_CUBE__ + + return; +} diff --git a/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/run_matmul.py b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/run_matmul.py new file mode 100644 index 00000000..cdb0929a --- /dev/null +++ b/examples/aot/batch_matmul/matmul_dynbatch_multicore_2buf/run_matmul.py @@ -0,0 +1,204 @@ +from typing import Callable, List, Literal, Union +import ctypes +import time +import argparse + +from ptodsl.test_util import get_test_device +from ptodsl import do_bench + +import torch +import torch_npu + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def _dtype_nbytes(dtype: torch.dtype) -> int: + return torch.empty((), dtype=dtype).element_size() + + +def matmul_flops(batch_size: int, m: int, k: int, n: int) -> int: + return 2 * batch_size * m * k * n + + +def matmul_io_bytes(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> int: + # Simple traffic model: read A + read B + write C. 
+ elt = _dtype_nbytes(a.dtype) + return (a.numel() + b.numel() + c.numel()) * elt + + +def benchmark( + fn, + *, + flops: int | None = None, + io_bytes: int | None = None, +) -> dict: + avg_s = do_bench(fn, unit="s", flush_cache=True) + stats = {"avg_ms": avg_s * 1e3} + if flops is not None: + stats["tflops"] = (flops / avg_s) / 1e12 + if io_bytes is not None: + stats["gbps"] = (io_bytes / avg_s) / 1e9 + return stats + + +def print_benchmark(stats: dict) -> None: + parts = [f"{stats['name']}: {stats['avg_ms']:.3f} ms"] + if "tflops" in stats: + parts.append(f"{stats['tflops']:.2f} TFLOP/s") + if "gbps" in stats: + parts.append(f"{stats['gbps']:.2f} GB/s (A+B+C)") + print(" | ".join(parts)) + + +def load_lib(lib_path): + lib = ctypes.CDLL(lib_path) + + def matmul_func(c, a, b, batch_size, block_dim, stream_ptr=None): + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ + lib.call_kernel( + block_dim, + stream_ptr, + torch_to_ctypes(c), + torch_to_ctypes(a), + torch_to_ctypes(b), + ctypes.c_uint32(batch_size), + ) + + return matmul_func + + +def plot_benchmark(): + import matplotlib.pyplot as plt + + device = get_test_device() + torch.set_default_device(device) + torch.npu.set_device(device) + dtype = torch.float16 + torch.manual_seed(0) + + matmul_func = load_lib("./matmul_kernel.so") + + pto_results, torch_results, pto2_results, pto3_results = [], [], [], [] + m, k, n = 128, 128, 128 + batches = list(range(24 * 2, 8000, 24 * 2)) + blk = [24, 1, 6] + for i in batches: + bs = i + a = torch.rand((bs, m, k), device=device, dtype=dtype) + b = torch.rand((k, n), device=device, dtype=dtype) + c = torch.zeros((bs, m, n), device=device, dtype=dtype) + + # correctness check + matmul_func(c, a, b, batch_size=bs, block_dim=24) + torch.npu.synchronize() + c_ref = torch.matmul(a, b) + diff = (c - c_ref).abs().max() + # assert diff <= 1e-5, diff + if diff < 1e-5: + print(".", end="") + else: + print(f"failed at shape: {a.shape} with {diff}") + + flops 
= matmul_flops(bs, m, k, n) + io_bytes = matmul_io_bytes(a, b, c) + + # run a benchmark for warmup (else first iterations are off) + benchmark(lambda: torch.matmul(a, b, out=c)) + + torch_b = benchmark( + lambda: torch.matmul(a, b, out=c), flops=flops, io_bytes=io_bytes + )["gbps"] + pto2 = benchmark( + lambda: matmul_func(c, a, b, batch_size=bs, block_dim=blk[1]), + flops=flops, + io_bytes=io_bytes, + )["gbps"] + pto3 = benchmark( + lambda: matmul_func(c, a, b, batch_size=bs, block_dim=blk[2]), + flops=flops, + io_bytes=io_bytes, + )["gbps"] + pto = benchmark( + lambda: matmul_func(c, a, b, batch_size=bs, block_dim=blk[0]), + flops=flops, + io_bytes=io_bytes, + )["gbps"] + pto_results.append(pto) + pto2_results.append(pto2) + pto3_results.append(pto3) + torch_results.append(torch_b) + print() + rel_diff = [our / their for our, their in zip(pto_results, torch_results)] + + fig, ax1 = plt.subplots(figsize=(8, 5)) + + ax1.plot(batches, pto_results, "-", label=f"pto-dsl ({blk[0]} cores)") + ax1.plot(batches, pto2_results, "-", label=f"pto-dsl ({blk[1]} cores)") + ax1.plot(batches, pto3_results, "-", label=f"pto-dsl ({blk[2]} cores)") + ax1.plot(batches, torch_results, "-", label="torch.matmul (24 cores)") + ax1.set_xlabel("Batch size") + ax1.set_ylabel("Bandwidth (Read A+B write C) (GB/s)") + ax1.grid(True, linestyle="--", alpha=0.6) + + ax2 = ax1.twinx() + ax2.plot(batches, rel_diff, "-", color="purple", label="pto-dsl / torch") + ax2.set_ylabel("Relative Performance (pto-dsl / torch)") + ax2.set_ylim(0.95 * min(rel_diff), 1.05 * max(rel_diff)) + ax2.axhline(y=1, linestyle="--", linewidth=1.0) + + dt_str = {torch.float16: "fp16", torch.float32: "fp32"}[dtype] + plt.title( + f"""pto-dsl kernel vs torch.matmul\n + @<{b.shape[0]}, {b.shape[1]}, {dt_str}>=""" + ) + + lines1, labels1 = ax1.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax1.legend(lines1 + lines2, labels1 + labels2, loc="best") + plt.tight_layout() + 
plt.savefig("dsl.png") + + +def correctness_verify(): + device = get_test_device() + torch.set_default_device(device) + torch.npu.set_device(device) + dtype = torch.float16 + torch.manual_seed(0) + + matmul_func = load_lib("./matmul_kernel.so") + + m, k, n = 128, 128, 128 + for blk in [1, 24]: + for bs in range(1000, 1100): + a = torch.rand((bs, m, k), device=device, dtype=dtype) + b = torch.rand((k, n), device=device, dtype=dtype) + c = torch.zeros((bs, m, n), device=device, dtype=dtype) + + matmul_func(c, a, b, batch_size=bs, block_dim=blk) + torch.npu.synchronize() + c_ref = torch.matmul(a, b) + + diff = (c - c_ref).abs().max() + # assert diff <= 1e-5, diff + if diff < 1e-5: + print(".", end="", flush=True) + else: + print( + f"#cores={blk} failed at shape: {list(a.shape)} with error:{diff}" + ) + print() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--benchmark", dest="benchmark", action="store_true", help="Enable benchmarking" + ) + args = parser.parse_args() + correctness_verify() + if args.benchmark: + plot_benchmark() diff --git a/examples/aot/matmul_dynbatch_multicore_opt/.gitignore b/examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/.gitignore similarity index 86% rename from examples/aot/matmul_dynbatch_multicore_opt/.gitignore rename to examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/.gitignore index b9455c5b..7eac319a 100644 --- a/examples/aot/matmul_dynbatch_multicore_opt/.gitignore +++ b/examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/.gitignore @@ -2,4 +2,4 @@ matmul.pto matmul.cpp matmul_kernel.so -*.png \ No newline at end of file +*.png diff --git a/examples/aot/matmul_static_singlecore/README.md b/examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/README.md similarity index 100% rename from examples/aot/matmul_static_singlecore/README.md rename to examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/README.md diff --git 
a/examples/aot/matmul_dynbatch_multicore_opt/caller.cpp b/examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/caller.cpp similarity index 100% rename from examples/aot/matmul_dynbatch_multicore_opt/caller.cpp rename to examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/caller.cpp diff --git a/examples/aot/matmul_dynbatch_multicore_opt/compile.sh b/examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/compile.sh similarity index 100% rename from examples/aot/matmul_dynbatch_multicore_opt/compile.sh rename to examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/compile.sh diff --git a/examples/aot/matmul_dynbatch_multicore_opt/matmul_builder.py b/examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/matmul_builder.py similarity index 78% rename from examples/aot/matmul_dynbatch_multicore_opt/matmul_builder.py rename to examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/matmul_builder.py index 61ace0b7..b711f4d6 100644 --- a/examples/aot/matmul_dynbatch_multicore_opt/matmul_builder.py +++ b/examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/matmul_builder.py @@ -1,7 +1,7 @@ # adapted from https://github.com/zhangstevenunity/PTOAS/blob/a301aa43b388d9b2e1ba0db8773b3a719e8c445b/test/samples/MatMul/tmatmulk.py -from ptodsl import to_ir_module -import ptodsl.language as pto +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s def build( @@ -29,7 +29,9 @@ def meta_data(): tile_buf_aMat = pto.TileBufType(shape=[M, K], dtype=dtype, memory_space="MAT") tile_buf_bMat = pto.TileBufType(shape=[K, N], dtype=dtype, memory_space="MAT") tile_buf_aTile = pto.TileBufType(shape=[M, K], dtype=dtype, memory_space="LEFT") - tile_buf_bTile = pto.TileBufType(shape=[K, N], dtype=dtype, memory_space="RIGHT") + tile_buf_bTile = pto.TileBufType( + shape=[K, N], dtype=dtype, memory_space="RIGHT" + ) tile_buf_cTile = pto.TileBufType(shape=[M, N], dtype=dtype, memory_space="ACC") return { @@ -49,7 +51,7 @@ def meta_data(): "tile_buf_cTile": 
tile_buf_cTile, } - const = pto.const + const = s.const @to_ir_module(meta_data=meta_data) def RunTMATMULSplitK( @@ -75,23 +77,27 @@ def RunTMATMULSplitK( cTileM = const(M) cTileN = const(N) - batch = pto.index_cast(batch_i32) + batch = s.index_cast(batch_i32) # Distribute batches over cores with "base + remainder" policy. - num_blocks = pto.index_cast(pto.get_block_num()) - bid = pto.index_cast(pto.get_block_idx()) + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) base = batch // num_blocks rem = batch % num_blocks - lt_rem = pto.lt(bid, rem) - min_bid_rem = pto.min_u(bid, rem) + lt_rem = s.lt(bid, rem) + min_bid_rem = s.min_u(bid, rem) b_start = bid * base + min_bid_rem - length = base + pto.select(lt_rem, c1, c0) - b_end = pto.min_u(b_start + length, batch) + length = base + s.select(lt_rem, c1, c0) + b_end = s.min_u(b_start + length, batch) - tvA = pto.as_tensor(tv_a, ptr=a_ptr, shape=[batch, cM, cK], strides=[cKM, cK, c1]) + tvA = pto.as_tensor( + tv_a, ptr=a_ptr, shape=[batch, cM, cK], strides=[cKM, cK, c1] + ) tvB = pto.as_tensor(tv_b, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1]) - tvOut = pto.as_tensor(tv_out, ptr=out_ptr, shape=[batch, cM, cN], strides=[cMN, cN, c1]) + tvOut = pto.as_tensor( + tv_out, ptr=out_ptr, shape=[batch, cM, cN], strides=[cMN, cN, c1] + ) aMatTile = pto.alloc_tile(tile_buf_aMat) bMatTile = pto.alloc_tile(tile_buf_bMat) @@ -100,12 +106,14 @@ def RunTMATMULSplitK( cTile = pto.alloc_tile(tile_buf_cTile) # B is shared across batches: load once GM->L1->L0B. 
- svB = pto.slice_view(tile_view_b, source=tvB, offsets=[c0, c0], sizes=[cK, cTileN]) + svB = pto.slice_view( + tile_view_b, source=tvB, offsets=[c0, c0], sizes=[cK, cTileN] + ) pto.load(svB, bMatTile) pto.record_wait_pair("LOAD", "MOV_M2L", event_id=0) - pto.mov(bMatTile, bTile) + tile.mov(bMatTile, bTile) - for b_idx in pto.for_range(b_start, b_end, c1): + for b_idx in pto.range(b_start, b_end, c1): svA = pto.slice_view( tile_view_a, source=tvA, @@ -122,9 +130,9 @@ def RunTMATMULSplitK( pto.load(svA, aMatTile) pto.record_wait_pair("LOAD", "MOV_M2L", event_id=0) - pto.mov(aMatTile, aTile) + tile.mov(aMatTile, aTile) pto.record_wait_pair("MOV_M2L", "MATMUL", event_id=0) - pto.matmul(aTile, bTile, cTile) + tile.matmul(aTile, bTile, cTile) pto.record_wait_pair("MATMUL", "LOAD", event_id=0) pto.record_wait_pair("MATMUL", "STORE_ACC", event_id=0) @@ -136,4 +144,4 @@ def RunTMATMULSplitK( if __name__ == "__main__": m = build() - print(m) \ No newline at end of file + print(m) diff --git a/examples/aot/matmul_dynbatch_multicore_opt/run_matmul.py b/examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/run_matmul.py similarity index 76% rename from examples/aot/matmul_dynbatch_multicore_opt/run_matmul.py rename to examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/run_matmul.py index 31c8ba78..58f8cb55 100644 --- a/examples/aot/matmul_dynbatch_multicore_opt/run_matmul.py +++ b/examples/aot/batch_matmul/matmul_dynbatch_multicore_opt/run_matmul.py @@ -3,6 +3,7 @@ import torch import torch_npu from ptodsl.test_util import get_test_device + try: import matplotlib.pyplot as plt except ImportError: @@ -29,9 +30,7 @@ def matmul_io_bytes(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> int: # Does not include cache effects or intermediate buffers. 
elt = _dtype_nbytes(a.dtype) return (a.numel() + b.numel() + c.numel()) * elt - #return (a.numel() + b.numel()) * elt - - + # return (a.numel() + b.numel()) * elt def benchmark( @@ -44,7 +43,7 @@ def benchmark( flops: int | None = None, io_bytes: int | None = None, ) -> dict: - avg_s = do_bench(fn, warmup_iters=warmup, benchmark_iters=iters, unit='s') + avg_s = do_bench(fn, warmup_iters=warmup, benchmark_iters=iters, unit="s") stats = {"name": name, "iters": iters, "avg_ms": avg_s * 1e3} if flops is not None: stats["tflops"] = (flops / avg_s) / 1e12 @@ -65,11 +64,7 @@ def print_benchmark(stats: dict) -> None: def load_lib(lib_path): lib = ctypes.CDLL(lib_path) - def matmul_func( - c, a, b, batch_size, - block_dim, - stream_ptr=None - ): + def matmul_func(c, a, b, batch_size, block_dim, stream_ptr=None): if stream_ptr is None: stream_ptr = torch.npu.current_stream()._as_parameter_ lib.call_kernel( @@ -84,8 +79,6 @@ def matmul_func( return matmul_func - - def plot_benchmark(): device = get_test_device() torch.set_default_device(device) @@ -98,7 +91,7 @@ def plot_benchmark(): matmul_func = load_lib("./matmul_kernel.so") # assume defined torch.manual_seed(0) - bs, m, k, n = 24*200, 128, 128, 128 + bs, m, k, n = 24 * 200, 128, 128, 128 for blk in blk_values: a = torch.rand((bs, m, k), device=device, dtype=dtype) b = torch.rand((k, n), device=device, dtype=dtype) @@ -109,19 +102,29 @@ def plot_benchmark(): torch.npu.synchronize() c_ref = torch.matmul(a, b) diff = (c - c_ref).abs().max() - assert diff <= 1e-5, diff + assert diff <= 1e-5, diff flops = matmul_flops(bs, m, k, n) io_bytes = matmul_io_bytes(a, b, c) - torch_b = benchmark("torch.matmul", - lambda: torch.matmul(a, b, out=c), - device=device, warmup=20, iters=20, - flops=flops, io_bytes=io_bytes)['gbps'] - pto = benchmark("custom_kernel", - lambda: matmul_func(c, a, b, batch_size=bs, block_dim=blk), - device=device, warmup=20, iters=20, - flops=flops, io_bytes=io_bytes)['gbps'] + torch_b = benchmark( + 
"torch.matmul", + lambda: torch.matmul(a, b, out=c), + device=device, + warmup=20, + iters=20, + flops=flops, + io_bytes=io_bytes, + )["gbps"] + pto = benchmark( + "custom_kernel", + lambda: matmul_func(c, a, b, batch_size=bs, block_dim=blk), + device=device, + warmup=20, + iters=20, + flops=flops, + io_bytes=io_bytes, + )["gbps"] pto_results.append(pto) torch_results.append(torch_b) @@ -136,20 +139,21 @@ def plot_benchmark(): return # plot results - plt.figure(figsize=(8,5)) - plt.plot(blk_values, pto_results, 'o-', label='mlir') - plt.plot(blk_values, torch_results, 's-', label='torch.matmul (all cores)') - plt.xlabel('Number of cores') - plt.ylabel('Bandwidth (Read A+B write C) (GB/s)') + plt.figure(figsize=(8, 5)) + plt.plot(blk_values, pto_results, "o-", label="mlir") + plt.plot(blk_values, torch_results, "s-", label="torch.matmul (all cores)") + plt.xlabel("Number of cores") + plt.ylabel("Bandwidth (Read A+B write C) (GB/s)") plt.title( f"""Benchmark: Custom Kernel vs torch.matmul\n A: {tuple(a.shape)} B: {tuple(b.shape)}, C: {tuple(c.shape)} \n A+B+C size: {total_mb:.1f} MB""" ) - plt.grid(True, linestyle='--', alpha=0.6) + plt.grid(True, linestyle="--", alpha=0.6) plt.legend() plt.tight_layout() - plt.savefig('our.png') + plt.savefig("our.png") + if __name__ == "__main__": plot_benchmark() diff --git a/examples/aot/add_dynamic_multicore/.gitignore b/examples/aot/elementwise/add_dynamic_multicore/.gitignore similarity index 76% rename from examples/aot/add_dynamic_multicore/.gitignore rename to examples/aot/elementwise/add_dynamic_multicore/.gitignore index 34d860d5..1dde81a9 100644 --- a/examples/aot/add_dynamic_multicore/.gitignore +++ b/examples/aot/elementwise/add_dynamic_multicore/.gitignore @@ -3,4 +3,4 @@ add.pto add_lib.so add_double.cpp add_double.pto -add_double_lib.so \ No newline at end of file +add_double_lib.so diff --git a/examples/aot/add_dynamic_multicore/README.md b/examples/aot/elementwise/add_dynamic_multicore/README.md similarity index 
100% rename from examples/aot/add_dynamic_multicore/README.md rename to examples/aot/elementwise/add_dynamic_multicore/README.md diff --git a/examples/aot/add_dynamic_multicore/add_builder.py b/examples/aot/elementwise/add_dynamic_multicore/add_builder.py similarity index 70% rename from examples/aot/add_dynamic_multicore/add_builder.py rename to examples/aot/elementwise/add_dynamic_multicore/add_builder.py index 72804ce4..c4a67865 100644 --- a/examples/aot/add_dynamic_multicore/add_builder.py +++ b/examples/aot/elementwise/add_dynamic_multicore/add_builder.py @@ -1,7 +1,7 @@ -from ptodsl import to_ir_module -import ptodsl.language as pto +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s -const = pto.const +const = s.const def meta_data(): @@ -43,18 +43,16 @@ def vec_add_1d_dynamic( cid = pto.get_block_idx() sub_bid = pto.get_subblock_idx() sub_bnum = pto.get_subblock_num() - cidmul = cid * sub_bnum - vid = cidmul + sub_bid num_blocks = pto.get_block_num() # Convert i64/i32 values to index for arithmetic ops. 
- vid_idx = pto.index_cast(vid) - num_cores = pto.index_cast(num_blocks) - total_elements = pto.index_cast(argN) + vid = s.index_cast(cid * sub_bnum + sub_bid) # vector core index + num_cores = s.index_cast(num_blocks * sub_bnum) # number of vector cores + total_elements = s.index_cast(argN) - num_tiles_global = pto.ceil_div(total_elements, c_tile) - num_tiles_per_core = pto.ceil_div(num_tiles_global, num_cores) - tile_offset_this_core = vid_idx * num_tiles_per_core + num_tiles_global = s.ceil_div(total_elements, c_tile) + num_tiles_per_core = s.ceil_div(num_tiles_global, num_cores) + tile_offset_this_core = vid * num_tiles_per_core with pto.vector_section(): tv0 = pto.as_tensor(tensor_type, ptr=arg0, shape=[total_elements], strides=[c1]) @@ -71,29 +69,38 @@ def vec_add_1d_dynamic( need_truncate = tiles_end_this_core > num_tiles_global remaining_tiles = num_tiles_global - tile_offset_this_core - tiles_to_process = pto.select( + tiles_to_process = s.select( need_truncate, remaining_tiles, num_tiles_per_core ) elements_to_process = tiles_to_process * c_tile with pto.if_context(elements_to_process > c0): - for i in pto.for_range(c0, tiles_to_process, c1): + for i in pto.range(c0, tiles_to_process, c1): tile_offset_global = i + tile_offset_this_core offset_global = tile_offset_global * c_tile sv0 = pto.slice_view( - subtensor_type, source=tv0, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv0, + offsets=[offset_global], + sizes=[c_tile], ) sv1 = pto.slice_view( - subtensor_type, source=tv1, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv1, + offsets=[offset_global], + sizes=[c_tile], ) sv2 = pto.slice_view( - subtensor_type, source=tv2, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv2, + offsets=[offset_global], + sizes=[c_tile], ) pto.load(sv0, tb0) pto.load(sv1, tb1) - pto.add(tb0, tb1, tb2) + tile.add(tb0, tb1, tb2) pto.store(tb2, sv2) diff --git 
a/examples/aot/add_dynamic_multicore/add_double_builder.py b/examples/aot/elementwise/add_dynamic_multicore/add_double_builder.py similarity index 73% rename from examples/aot/add_dynamic_multicore/add_double_builder.py rename to examples/aot/elementwise/add_dynamic_multicore/add_double_builder.py index 436afeb1..022a18bd 100644 --- a/examples/aot/add_dynamic_multicore/add_double_builder.py +++ b/examples/aot/elementwise/add_dynamic_multicore/add_double_builder.py @@ -1,7 +1,7 @@ -from ptodsl import to_ir_module -import ptodsl.language as pto +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s -const = pto.const +const = s.const def meta_data(): @@ -44,18 +44,16 @@ def vec_add_1d_dynamic( cid = pto.get_block_idx() sub_bid = pto.get_subblock_idx() sub_bnum = pto.get_subblock_num() - cidmul = cid * sub_bnum - vid = cidmul + sub_bid num_blocks = pto.get_block_num() # Convert i64/i32 values to index for arithmetic ops. - vid_idx = pto.index_cast(vid) - num_cores = pto.index_cast(num_blocks) - total_elements = pto.index_cast(argN) + vid = s.index_cast(cid * sub_bnum + sub_bid) # vector core index + num_cores = s.index_cast(num_blocks * sub_bnum) # number of vector cores + total_elements = s.index_cast(argN) - num_tiles_global = pto.ceil_div(total_elements, c_tile) - num_tiles_per_core = pto.ceil_div(num_tiles_global, num_cores) - tile_offset_this_core = vid_idx * num_tiles_per_core + num_tiles_global = s.ceil_div(total_elements, c_tile) + num_tiles_per_core = s.ceil_div(num_tiles_global, num_cores) + tile_offset_this_core = vid * num_tiles_per_core with pto.vector_section(): tv0 = pto.as_tensor(tensor_type, ptr=arg0, shape=[total_elements], strides=[c1]) @@ -76,34 +74,43 @@ def vec_add_1d_dynamic( need_truncate = tiles_end_this_core > num_tiles_global remaining_tiles = num_tiles_global - tile_offset_this_core - tiles_to_process = pto.select( + tiles_to_process = s.select( need_truncate, remaining_tiles, num_tiles_per_core ) elements_to_process = 
tiles_to_process * c_tile with pto.if_context(elements_to_process > c0): - for i in pto.for_range(c0, tiles_to_process, c1): + for i in pto.range(c0, tiles_to_process, c1): tile_offset_global = i + tile_offset_this_core offset_global = tile_offset_global * c_tile sv0 = pto.slice_view( - subtensor_type, source=tv0, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv0, + offsets=[offset_global], + sizes=[c_tile], ) sv1 = pto.slice_view( - subtensor_type, source=tv1, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv1, + offsets=[offset_global], + sizes=[c_tile], ) sv2 = pto.slice_view( - subtensor_type, source=tv2, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv2, + offsets=[offset_global], + sizes=[c_tile], ) with pto.if_context((i % c2) == c0, has_else=True) as branch: pto.load(sv0, tb0_ping) pto.load(sv1, tb1_ping) - pto.add(tb0_ping, tb1_ping, tb2_ping) + tile.add(tb0_ping, tb1_ping, tb2_ping) pto.store(tb2_ping, sv2) with branch.else_context(): pto.load(sv0, tb0_pong) pto.load(sv1, tb1_pong) - pto.add(tb0_pong, tb1_pong, tb2_pong) + tile.add(tb0_pong, tb1_pong, tb2_pong) pto.store(tb2_pong, sv2) diff --git a/examples/aot/add_dynamic_multicore/bench_add.py b/examples/aot/elementwise/add_dynamic_multicore/bench_add.py similarity index 96% rename from examples/aot/add_dynamic_multicore/bench_add.py rename to examples/aot/elementwise/add_dynamic_multicore/bench_add.py index 618cbca9..54f60dd4 100644 --- a/examples/aot/add_dynamic_multicore/bench_add.py +++ b/examples/aot/elementwise/add_dynamic_multicore/bench_add.py @@ -24,7 +24,9 @@ def add_func(x, y, z, stream_ptr=None): return add_func -def bench_add(add_func, x, y, z, kernel_name="add_func", warmup_iters=5, benchmark_iters=50): +def bench_add( + add_func, x, y, z, kernel_name="add_func", warmup_iters=5, benchmark_iters=50 +): io_bytes = x.numel() * x.element_size() * 3 # Overwrite a large buffer between launches to reduce L2 cache reuse. 
def lib_to_func(lib):
    """Wrap ``lib.call_kernel`` in a Python callable that adds two tensors.

    The returned ``add_func(x, y, z, stream_ptr=None)`` forwards the raw device
    pointers of ``x``, ``y`` and ``z`` plus the element count of ``x`` to the
    compiled launcher. When no stream is supplied, the current NPU stream is
    used.
    """

    def add_func(x, y, z, stream_ptr=None):
        # Resolve the launch stream lazily so the caller's current stream is
        # picked up at call time rather than at wrapper-creation time.
        stream = (
            torch.npu.current_stream()._as_parameter_
            if stream_ptr is None
            else stream_ptr
        )
        element_count = x.numel()
        device_ptrs = [torch_to_ctypes(t) for t in (x, y, z)]
        lib.call_kernel(stream, *device_ptrs, element_count)

    return add_func
- tile_counts = [1, 7, num_cores - 1, num_cores + 3, 2 * num_cores + 7, 5 * num_cores - 5] + tile_counts = [ + 1, + 7, + num_cores - 1, + num_cores + 3, + 2 * num_cores + 7, + 5 * num_cores - 5, + ] shape_list = [tile_size * tiles for tiles in tile_counts] torch.manual_seed(0) @@ -60,6 +59,7 @@ def test_add(lib_path="./add_lib.so"): torch.testing.assert_close(z, z_ref) print(f"result equal for shape {shape}") + if __name__ == "__main__": test_add() test_add("./add_double_lib.so") diff --git a/examples/aot/fast_hadamard/.gitignore b/examples/aot/fast_hadamard/.gitignore new file mode 100644 index 00000000..663e5a84 --- /dev/null +++ b/examples/aot/fast_hadamard/.gitignore @@ -0,0 +1,9 @@ +hadamard_no_sync.pto +hadamard_manual_sync.pto +hadamard_auto_sync.cpp +hadamard_manual_sync.cpp +hadamard_auto_sync.pto +hadamard_auto_sync_lib.so +hadamard_manual_sync_lib.so + +perf_data* diff --git a/examples/aot/fast_hadamard/README.md b/examples/aot/fast_hadamard/README.md new file mode 100644 index 00000000..6b19cee9 --- /dev/null +++ b/examples/aot/fast_hadamard/README.md @@ -0,0 +1,8 @@ +Usage: + +```bash +bash ./compile.sh # generate PTO/CPP and build both auto/manual sync libs +python ./run_hadamard.py # test auto-sync lib (default) +python ./run_hadamard.py --manual-sync # test manual-sync lib +python ./plot_perf.py # optionally visualization +``` diff --git a/examples/aot/fast_hadamard/caller.cpp b/examples/aot/fast_hadamard/caller.cpp new file mode 100644 index 00000000..1ddaff6a --- /dev/null +++ b/examples/aot/fast_hadamard/caller.cpp @@ -0,0 +1,28 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "hadamard_auto_sync.cpp" +#endif +#include KERNEL_CPP + +#ifndef KERNEL_FN +#define KERNEL_FN fast_hadamard_autosync +#endif + +#ifndef NUM_CORES +#define NUM_CORES 24 +#endif + +extern "C" void call_kernel( + uint32_t blockDim, + void *stream, + uint8_t *x, + uint32_t batch, + uint32_t n, + uint32_t log2_n) +{ + uint32_t launch_blocks = blockDim > 0 ? 
blockDim : NUM_CORES; + KERNEL_FN<<>>( + reinterpret_cast(x), + static_cast(batch), + static_cast(n), + static_cast(log2_n)); +} diff --git a/examples/aot/fast_hadamard/compile.sh b/examples/aot/fast_hadamard/compile.sh new file mode 100644 index 00000000..a95f6148 --- /dev/null +++ b/examples/aot/fast_hadamard/compile.sh @@ -0,0 +1,46 @@ +set -e + +rm -f \ + hadamard_auto_sync.pto hadamard_manual_sync.pto \ + hadamard_auto_sync.cpp hadamard_manual_sync.cpp \ + hadamard_auto_sync_lib.so hadamard_manual_sync_lib.so + +# Auto-sync path: rely on ptoas synchronization insertion. +python ./hadamard_builder.py > ./hadamard_auto_sync.pto +ptoas --enable-insert-sync ./hadamard_auto_sync.pto -o ./hadamard_auto_sync.cpp + +# Manual-sync path: explicit record/wait events from builder. +python ./hadamard_builder.py --manual-sync > ./hadamard_manual_sync.pto +ptoas ./hadamard_manual_sync.pto -o ./hadamard_manual_sync.cpp + +bisheng \ + -I${ASCEND_TOOLKIT_HOME}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + ./caller.cpp \ + -o ./hadamard_auto_sync_lib.so + +bisheng \ + -I${ASCEND_TOOLKIT_HOME}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + -DKERNEL_CPP="\"hadamard_manual_sync.cpp\"" 
@to_ir_module(meta_data=meta_data)
def fast_hadamard_autosync(
    x_ptr: "ptr_type",
    batch_i32: "index_dtype",
    n_i32: "index_dtype",
    log2_n_i32: "index_dtype",
) -> None:
    """Build the auto-sync fast Hadamard kernel (no explicit events/barriers).

    Rows of length ``n`` are split across vector cores in contiguous slices of
    ``ceil(batch / num_cores)`` rows. Each row is loaded into a UB tile, run
    through ``log2_n`` gather/add/sub stages, and stored back to the same
    location in ``x_ptr``. Synchronization is left to
    ``ptoas --enable-insert-sync``.
    """
    c0 = const(0)
    c1 = const(1)
    c2 = const(2)

    batch = s.index_cast(batch_i32)
    n = s.index_cast(n_i32)
    log2_n = s.index_cast(log2_n_i32)

    cid = pto.get_block_idx()
    sub_bid = pto.get_subblock_idx()
    sub_bnum = pto.get_subblock_num()
    num_blocks = pto.get_block_num()

    vid = s.index_cast(cid * sub_bnum + sub_bid)  # vector core index
    num_cores = s.index_cast(num_blocks * sub_bnum)  # number of vector cores

    with pto.vector_section():
        samples_per_core = s.ceil_div(batch, num_cores)
        sample_offset = vid * samples_per_core

        # Cores whose slice starts past the end of the batch do nothing.
        with pto.if_context(sample_offset < batch):
            samples_end = sample_offset + samples_per_core
            # The last active core may own fewer rows than samples_per_core.
            samples_to_process = s.select(
                samples_end > batch,
                batch - sample_offset,
                samples_per_core,
            )

            with pto.if_context(samples_to_process > c0):
                total_elements = batch * n
                tv_x = pto.as_tensor(
                    tensor_type, ptr=x_ptr, shape=[total_elements], strides=[c1]
                )

                # Two independent tile sets (ping/pong) so event_id 0/1 map to
                # disjoint UB buffers, matching the manual C++ reference.
                tb_row_0 = pto.alloc_tile(tile_full, valid_col=n)
                tb_even_0 = pto.alloc_tile(tile_half, valid_col=n // c2)
                tb_odd_0 = pto.alloc_tile(tile_half, valid_col=n // c2)

                tb_row_1 = pto.alloc_tile(tile_full, valid_col=n)
                tb_even_1 = pto.alloc_tile(tile_half, valid_col=n // c2)
                tb_odd_1 = pto.alloc_tile(tile_half, valid_col=n // c2)

                n_half = n // c2

                # Keep one sample per chunk. Multi-sample chunks interact
                # poorly with static tile subset sizing in current PTO Python
                # bindings and can corrupt rows for larger batches.
                samples_per_load = c1
                num_chunks = s.ceil_div(samples_to_process, samples_per_load)

                def process_rows(tb_row, tb_even, tb_odd, gm_offset, cur_samples):
                    # The loop index is deliberately NOT called `s`: that name
                    # is the `ptodsl.scalar` module alias used by the kernel.
                    for row_i in pto.range(c0, cur_samples, c1):
                        row_offset = gm_offset + row_i * n
                        sv_row = pto.slice_view(
                            subtensor_full, source=tv_x, offsets=[row_offset], sizes=[n]
                        )
                        # Alias row halves inside UB row tile (no GM round-trip
                        # per Hadamard iteration).
                        tb_first = tile.subset(
                            tb_row, [c0, c0], [1, HALF_ELEMENTS_PER_TILE]
                        )
                        tb_second = tile.subset(
                            tb_row, [c0, n_half], [1, HALF_ELEMENTS_PER_TILE]
                        )

                        pto.load(sv_row, tb_row)
                        # One stage: split even/odd lanes, then write the sums
                        # to the first half and the differences to the second.
                        for _ in pto.range(c0, log2_n, c1):
                            tile.gather(tb_row, tb_even, mask_pattern="P0101")
                            tile.gather(tb_row, tb_odd, mask_pattern="P1010")
                            tile.add(tb_even, tb_odd, tb_first)
                            tile.sub(tb_even, tb_odd, tb_second)
                        pto.store(tb_row, sv_row)

                for chunk_i in pto.range(c0, num_chunks, c1):
                    sample_done = chunk_i * samples_per_load
                    chunk_left = samples_to_process - sample_done
                    cur_samples = s.select(
                        chunk_left < samples_per_load, chunk_left, samples_per_load
                    )

                    with pto.if_context(cur_samples > c0):
                        gm_offset = (sample_offset + sample_done) * n
                        # Alternate between the two tile sets on even/odd chunks.
                        use_ev0 = (chunk_i % c2) == c0

                        with pto.if_context(use_ev0, has_else=True) as branch:
                            process_rows(
                                tb_row_0, tb_even_0, tb_odd_0, gm_offset, cur_samples
                            )
                        with branch.else_context():
                            process_rows(
                                tb_row_1, tb_even_1, tb_odd_1, gm_offset, cur_samples
                            )
pto.if_context(samples_to_process > c0): + total_elements = batch * n + tv_x = pto.as_tensor( + tensor_type, ptr=x_ptr, shape=[total_elements], strides=[c1] + ) + + # Two independent tile sets (ping/pong) so event_id 0/1 map to + # disjoint UB buffers, matching the manual C++ reference. + tb_row_0 = pto.alloc_tile(tile_full, valid_col=n) + tb_even_0 = pto.alloc_tile(tile_half, valid_col=n // c2) + tb_odd_0 = pto.alloc_tile(tile_half, valid_col=n // c2) + + tb_row_1 = pto.alloc_tile(tile_full, valid_col=n) + tb_even_1 = pto.alloc_tile(tile_half, valid_col=n // c2) + tb_odd_1 = pto.alloc_tile(tile_half, valid_col=n // c2) + + n_half = n // c2 + + # Keep one sample per chunk. Multi-sample chunks interact + # poorly with static tile subset sizing in current PTO Python + # bindings and can corrupt rows for larger batches. + samples_per_load = c1 + num_chunks = s.ceil_div(samples_to_process, samples_per_load) + + def process_rows( + tb_row, tb_even, tb_odd, event_id, gm_offset, cur_samples + ): + for s in pto.range(c0, cur_samples, c1): + row_offset = gm_offset + s * n + sv_row = pto.slice_view( + subtensor_full, source=tv_x, offsets=[row_offset], sizes=[n] + ) + # Alias row halves inside UB row tile (no GM round-trip + # per Hadamard iteration). 
+ tb_first = tile.subset( + tb_row, [c0, c0], [1, HALF_ELEMENTS_PER_TILE] + ) + tb_second = tile.subset( + tb_row, [c0, n_half], [1, HALF_ELEMENTS_PER_TILE] + ) + + pto.wait_event("VEC", "LOAD", event_id=event_id) + pto.wait_event("STORE_VEC", "VEC", event_id=event_id) + pto.load(sv_row, tb_row) + pto.record_wait_pair("LOAD", "VEC", event_id=event_id) + + for _ in pto.range(c0, log2_n, c1): + tile.gather(tb_row, tb_even, mask_pattern="P0101") + tile.gather(tb_row, tb_odd, mask_pattern="P1010") + pto.barrier("VEC") + tile.add(tb_even, tb_odd, tb_first) + tile.sub(tb_even, tb_odd, tb_second) + pto.barrier("VEC") + + pto.record_wait_pair("VEC", "STORE_VEC", event_id=event_id) + pto.store(tb_row, sv_row) + pto.record_event("STORE_VEC", "VEC", event_id=event_id) + pto.record_event("VEC", "LOAD", event_id=event_id) + + for event_id in (0, 1): + pto.record_event("VEC", "LOAD", event_id=event_id) + pto.record_event("STORE_VEC", "VEC", event_id=event_id) + + for chunk_i in pto.range(c0, num_chunks, c1): + sample_done = chunk_i * samples_per_load + chunk_left = samples_to_process - sample_done + cur_samples = s.select( + chunk_left < samples_per_load, chunk_left, samples_per_load + ) + + with pto.if_context(cur_samples > c0): + gm_offset = (sample_offset + sample_done) * n + use_ev0 = (chunk_i % c2) == c0 + + with pto.if_context(use_ev0, has_else=True) as branch: + process_rows( + tb_row_0, tb_even_0, tb_odd_0, 0, gm_offset, cur_samples + ) + with branch.else_context(): + process_rows( + tb_row_1, tb_even_1, tb_odd_1, 1, gm_offset, cur_samples + ) + + for event_id in (0, 1): + pto.wait_event("VEC", "LOAD", event_id=event_id) + pto.wait_event("STORE_VEC", "VEC", event_id=event_id) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--manual-sync", + action="store_true", + help="Emit explicit record/wait events instead of relying on --enable-insert-sync.", + ) + args = parser.parse_args() + if args.manual_sync: + 
def plot_bandwidth(input_dir="./perf_data/", output_path="bw_vs_shape.png"):
    """Generate a bandwidth-vs-batch plot from the benchmark CSV files.

    Reads ``fht_pto_bd<block_dim>.csv`` from ``input_dir`` for each block
    dimension, draws one line per hidden dimension, and saves the figure as
    ``output_path`` inside ``input_dir``. Does nothing except print a warning
    when matplotlib is unavailable.

    Args:
        input_dir: Directory holding the CSVs; also receives the image.
        output_path: File name of the image, relative to ``input_dir``.
    """
    if plt is None:
        print("Warning: matplotlib is not installed; skipping plot generation.")
        return

    BENCH_BATCHES = [1, 5, 8, 10, 16, 20, 32, 40, 64, 128, 256, 512, 1024]
    BENCH_BLOCK_DIMS = [20, 24]
    # Markers used once the first ten series (all "o") are exhausted; cycled so
    # that any number of hidden dimensions can be plotted.
    extra_markers = ("s", "^", "D")

    fig, axes = plt.subplots(1, len(BENCH_BLOCK_DIMS), figsize=(14, 6), sharey=True)
    if len(BENCH_BLOCK_DIMS) == 1:
        axes = [axes]

    try:
        for ax, block_dim in zip(axes, BENCH_BLOCK_DIMS):
            csv_path = os.path.join(input_dir, f"fht_pto_bd{block_dim}.csv")
            if not os.path.exists(csv_path):
                ax.set_title(f"BLOCK_DIM={block_dim} (no data)")
                continue

            # Parse CSV: hidden_dim -> {batch: bw}
            data = {}
            with open(csv_path, encoding="utf-8") as f:
                for row in csv.DictReader(f):
                    batch = int(row["batch"])
                    n = int(row["N"])
                    data.setdefault(n, {})[batch] = float(row["bandwidth_gbs"])

            for idx, hidden_dim in enumerate(sorted(data)):
                series = data[hidden_dim]
                batches = sorted(series)
                bws = [series[b] for b in batches]

                if idx < 10:
                    marker = "o"
                else:
                    marker = extra_markers[(idx - 10) % len(extra_markers)]

                ax.plot(
                    batches,
                    bws,
                    marker=marker,
                    markersize=4,
                    label=f"hidden_dim={hidden_dim}",
                )

            ax.set_xscale("log", base=2)
            ax.set_xticks(BENCH_BATCHES)
            ax.set_xticklabels(
                [str(b) for b in BENCH_BATCHES], rotation=45, fontsize=7
            )
            ax.set_xlabel("batch")
            ax.set_title(f"BLOCK_DIM={block_dim}")
            ax.grid(True, alpha=0.3)
            ax.legend(fontsize=7, ncol=2)

        axes[0].set_ylabel("Bandwidth (GB/s)")
        fig.suptitle("Fast Hadamard PTO-DSL: Bandwidth vs Shape")
        fig.tight_layout()
        # Join instead of concatenating so a directory given without a trailing
        # separator still places the image inside it.
        saved_to = os.path.join(input_dir, output_path)
        fig.savefig(saved_to, dpi=150)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    print(f"\nPlot saved to {saved_to}")
torch.float16 + batch_list = [1, 7, 29, 65] + n_list = [128, 256, 512, 1024, 2048, 4096, 8192, 16384] + + results = [] + for batch in batch_list: + for n in n_list: + if not _is_power_of_two(n): + continue + log2_n = int(math.log2(n)) + x = torch.randn(batch, n, device=device, dtype=dtype) + y_ref = hadamard_ref_inplace(x) + + hadamard_func(x, batch, n, log2_n) + torch.npu.synchronize() + + is_match = True + detail = "" + try: + torch.testing.assert_close(x, y_ref) + except AssertionError as err: + is_match = False + detail = str(err).strip() if str(err) else "assert_close failed" + + status = "match" if is_match else "mismatch" + print(f"[{status}] batch={batch}, n={n}, lib={lib_path}") + if detail: + print(" detail:") + print(detail) + results.append((batch, n, status, detail)) + + print(f"detailed summary for {lib_path}:") + for batch, n, status, detail in results: + msg = f" batch={batch}, n={n}, status={status}" + print(msg) + if detail: + print(" detail:") + print(detail) + return results + + +def benchmark(hadamard_func, warmup=2, repeats=20, output_dir="./perf_data/"): + """Benchmark across (batch, N, block_dim) configs. + + Uses separate input tensors per run to avoid L2 cache reuse, + and a single timing-event pair averaged over all runs. 
+ """ + TEST_HIDDEN_DIMS = [128, 256, 512, 1024, 2048, 4096, 8192, 16384] + BENCH_BATCHES = [1, 5, 8, 10, 16, 20, 32, 40, 64, 128, 256, 512, 1024] + BENCH_BLOCK_DIMS = [20, 24] + + os.makedirs(output_dir, exist_ok=True) + + for block_dim in BENCH_BLOCK_DIMS: + print(f"\n{'=' * 60}") + print(f"BENCHMARK (BLOCK_DIM={block_dim})") + print(f"{'=' * 60}") + header = ( + f"{'batch':>6s} {'N':>6s}" + f" {'duration_us':>12s} {'bandwidth_gbs':>14s}" + ) + print(header) + print("-" * len(header)) + + records = [] + + for batch in BENCH_BATCHES: + for n in TEST_HIDDEN_DIMS: + log2_n = int(math.log2(n)) + allocated = warmup + repeats + + # Separate GM tensors to avoid L2 cache reuse + x_list = [ + torch.randn(batch, n, device="npu", dtype=torch.float16) + for _ in range(allocated) + ] + + # Warmup + for i in range(warmup): + hadamard_func(x_list[i], batch, n, log2_n, block_dim=block_dim) + torch.npu.synchronize() + + # Timed runs — single event pair, average over repeats + start = torch.npu.Event(enable_timing=True) + end = torch.npu.Event(enable_timing=True) + + start.record() + for i in range(repeats): + hadamard_func( + x_list[warmup + i], + batch, + n, + log2_n, + block_dim=block_dim, + ) + end.record() + torch.npu.synchronize() + + duration_ms = start.elapsed_time(end) / repeats + dur_us = duration_ms * 1e3 + + # Bandwidth: read + write = 2 * batch * n * sizeof(half) + data_bytes = 2 * batch * n * 2 + bw_gbs = (data_bytes / 1e9) / (dur_us / 1e6) if dur_us > 0 else 0.0 + + print(f"{batch:>6d} {n:>6d}" f" {dur_us:>12.2f} {bw_gbs:>14.2f}") + records.append(f"{batch},{n},{dur_us:.4f},{bw_gbs:.4f}") + + csv_path = os.path.join(output_dir, f"fht_pto_bd{block_dim}.csv") + with open(csv_path, "w", encoding="utf-8") as f: + f.write("batch,N,duration_us,bandwidth_gbs\n") + f.write("\n".join(records) + "\n") + print(f"\nSaved to {csv_path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--manual-sync", + action="store_true", + 
help="Use manual-sync library instead of the default auto-sync library.", + ) + parser.add_argument( + "--block-dim", + type=int, + default=24, + help="Kernel blockDim (default: 24).", + ) + args = parser.parse_args() + + lib_path = ( + "./hadamard_manual_sync_lib.so" + if args.manual_sync + else "./hadamard_auto_sync_lib.so" + ) + + device = get_test_device() + torch.npu.set_device(device) + hadamard_func = load_lib(lib_path=lib_path, block_dim=args.block_dim) + + test_hadamard(hadamard_func) + benchmark(hadamard_func) diff --git a/examples/aot/fast_inverse/basic_dense/.gitignore b/examples/aot/fast_inverse/basic_dense/.gitignore new file mode 100644 index 00000000..e33609d2 --- /dev/null +++ b/examples/aot/fast_inverse/basic_dense/.gitignore @@ -0,0 +1 @@ +*.png diff --git a/examples/aot/fast_inverse/basic_dense/README.md b/examples/aot/fast_inverse/basic_dense/README.md new file mode 100644 index 00000000..acc5a93c --- /dev/null +++ b/examples/aot/fast_inverse/basic_dense/README.md @@ -0,0 +1,27 @@ +```bash +bash compile.sh 64 # build -> inverse_lib.so + +# Validate correctness +python run_inverse.py --matrix-size 64 + +# Another matrix size +python run_inverse.py --matrix-size 128 + +# Measure effective bandwidth +python bench_inverse.py --matrix-size 64 --out-png bench_inverse_bandwidth.png +``` + +`bench_inverse.py` reports and plots bandwidth using only: +- read of `in_delta` (`torch_to_ctypes(in_delta)`) +- write of `out` (`torch_to_ctypes(out)`) + +Timing measures only the kernel launch (`lib.call_kernel(...)`) and excludes tensor +preparation (`identity`, `in_delta`, `identity_neg`, `out` creation). + +This dense demo uses input shape `[batch, n, n]` and applies the same fast-inverse recurrence +as the block-diagonal example, with `log2_blocksize = log2(n)` (no extra diagonal block size). +It uses persistent-kernel style launch with fixed `blockDim=24`, and each core loops over +its assigned batch indices at runtime. 
+ +For numerical stability in this educational demo, test inputs are generated as: +`M = I + scale * random`, and the kernel computes `inv(M)` via `A = M - I`. diff --git a/examples/aot/fast_inverse/basic_dense/bench_inverse.py b/examples/aot/fast_inverse/basic_dense/bench_inverse.py new file mode 100644 index 00000000..893f7ddf --- /dev/null +++ b/examples/aot/fast_inverse/basic_dense/bench_inverse.py @@ -0,0 +1,200 @@ +import argparse +import ctypes +import math +import random + +import numpy as np +import torch +import torch_npu # noqa: F401 + +try: + import matplotlib.pyplot as plt +except ImportError: + plt = None + +from ptodsl import do_bench +from ptodsl.test_util import get_test_device + +random.seed(42) +torch.manual_seed(42) +np.random.seed(42) + +SUPPORTED_MATRIX_SIZES = (16, 32, 64, 128) +DEFAULT_BATCH_SIZES = [2**k for k in range(4, 16)] # 16, 32, ..., 32768 +try: + PERSISTENT_BLOCK_DIM = int(torch.npu.get_device_properties("npu").cube_core_num) +except Exception: + PERSISTENT_BLOCK_DIM = 24 + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def _dtype_nbytes(dtype: torch.dtype) -> int: + return torch.empty((), dtype=dtype).element_size() + + +def inverse_io_bytes(in_delta: torch.Tensor, out: torch.Tensor) -> int: + # Requested traffic model: read in_delta + write out only. 
def dense_stable_matrix(n, batch, scale=0.02):
    """Return a ``(batch, n, n)`` float32 tensor of perturbed identity matrices.

    Every matrix is ``I + scale * U`` where ``U`` is drawn uniformly from
    ``[-1, 1)`` using NumPy's global random state.
    """
    perturbation = np.random.uniform(-1.0, 1.0, size=(batch, n, n)).astype(
        np.float32
    )
    identity = np.eye(n, dtype=np.float32)
    # Broadcasting adds the same identity to every matrix in the batch.
    matrices = identity + scale * perturbation
    return torch.from_numpy(matrices)
def plot_results(
    batch_sizes: list[int],
    bw_gib_s: list[float],
    out_png: str,
    n: int,
) -> None:
    """Plot measured bandwidth against batch size and save it to ``out_png``.

    Args:
        batch_sizes: Batch sizes, used as x values and as tick labels.
        bw_gib_s: Bandwidth for each batch size, in GiB/s.
        out_png: Path of the image file to write.
        n: Matrix size, shown in the plot title.

    Prints a warning and returns without plotting when matplotlib is missing.
    """
    if plt is None:
        print("Warning: matplotlib is not installed; skipping plot generation.")
        return

    fig = plt.figure(figsize=(8, 5))
    try:
        plt.plot(batch_sizes, bw_gib_s, "o-", label="kernel")
        plt.xlabel("Batch size")
        plt.ylabel("Bandwidth (GiB/s)")
        plt.title(f"Fast Inverse Bandwidth (n={n})")
        plt.xscale("log", base=2)
        plt.xticks(batch_sizes, [str(x) for x in batch_sizes])
        plt.grid(True, linestyle="--", alpha=0.6)
        plt.legend()
        plt.tight_layout()
        plt.savefig(out_png)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly, so
        # close it even when plotting or saving fails.
        plt.close(fig)
    print(f"Saved plot to {out_png}")
help="Number of measured iterations for each batch size.", + ) + parser.add_argument( + "--lib-path", + type=str, + default="./inverse_lib.so", + help="Shared library path produced by compile.sh.", + ) + parser.add_argument( + "--out-png", + type=str, + default="bench_inverse_bandwidth.png", + help="Output image path for the benchmark figure.", + ) + args = parser.parse_args() + + device = get_test_device() + torch.npu.set_device(device) + + lib = load_lib(args.lib_path) + bw = run_benchmark( + lib, + label="kernel", + matrix_size=args.matrix_size, + batch_sizes=args.batch_sizes, + warmup=args.warmup, + iters=args.iters, + ) + plot_results(args.batch_sizes, bw, args.out_png, args.matrix_size) diff --git a/examples/aot/fast_inverse/basic_dense/caller.cpp b/examples/aot/fast_inverse/basic_dense/caller.cpp new file mode 100644 index 00000000..131846e6 --- /dev/null +++ b/examples/aot/fast_inverse/basic_dense/caller.cpp @@ -0,0 +1,26 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "inverse.cpp" +#endif + +#ifndef KERNEL_FN +#define KERNEL_FN tri_inv_trick_fp16 +#endif + +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, + void *stream, + uint8_t *tensor_out, + uint8_t *tensor_in, + uint8_t *identity_in, + uint32_t runtime_batch_size, + uint32_t log2_blocksize) +{ + KERNEL_FN<<>>( + reinterpret_cast(tensor_out), + reinterpret_cast(tensor_in), + reinterpret_cast(identity_in), + static_cast(runtime_batch_size), + static_cast(log2_blocksize)); +} diff --git a/examples/aot/fast_inverse/basic_dense/compile.sh b/examples/aot/fast_inverse/basic_dense/compile.sh new file mode 100644 index 00000000..94619904 --- /dev/null +++ b/examples/aot/fast_inverse/basic_dense/compile.sh @@ -0,0 +1,34 @@ +#!/usr/bin/env bash +set -euo pipefail + +ARTIFACT_DIR="./build_artifacts" +MATRIX_SIZE="${1:-64}" +if [[ $# -gt 1 ]]; then + echo "Usage: bash compile.sh [matrix_size]" + exit 1 +fi + +mkdir -p "${ARTIFACT_DIR}" +rm -f "${ARTIFACT_DIR}/inverse.pto" 
"${ARTIFACT_DIR}/inverse.cpp" "inverse_lib.so" + +python ./inverse_builder.py --matrix-size "${MATRIX_SIZE}" > "${ARTIFACT_DIR}/inverse.pto" +ptoas --enable-insert-sync "${ARTIFACT_DIR}/inverse.pto" -o "${ARTIFACT_DIR}/inverse.cpp" + +PTO_LIB_PATH=/sources/pto-isa +# PTO_LIB_PATH=$ASCEND_TOOLKIT_HOME + +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + -DKERNEL_CPP="\"${ARTIFACT_DIR}/inverse.cpp\"" \ + ./caller.cpp \ + -o "./inverse_lib.so" diff --git a/examples/aot/fast_inverse/basic_dense/inverse_builder.py b/examples/aot/fast_inverse/basic_dense/inverse_builder.py new file mode 100644 index 00000000..498abdc5 --- /dev/null +++ b/examples/aot/fast_inverse/basic_dense/inverse_builder.py @@ -0,0 +1,174 @@ +# pyright: reportUndefinedVariable=false +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const +SUPPORTED_MATRIX_SIZES = (16, 32, 64, 128) + + +def make_meta_data(n: int): + def meta_data(): + in_dtype = pto.float16 + out_dtype = pto.float32 + i32 = pto.int32 + + in_ptr_type = pto.PtrType(in_dtype) + out_ptr_type = pto.PtrType(out_dtype) + in_tensor_type = pto.TensorType(rank=2, dtype=in_dtype) + out_tensor_type = pto.TensorType(rank=2, dtype=out_dtype) + in_subtensor = pto.SubTensorType(shape=[n, n], dtype=in_dtype) + out_subtensor = pto.SubTensorType(shape=[n, n], dtype=out_dtype) + l1_tile_type = pto.TileBufType( + shape=[n, n], valid_shape=[n, n], dtype=in_dtype, memory_space="MAT" + ) + l0a_tile_type = pto.TileBufType( + shape=[n, n], valid_shape=[n, n], dtype=in_dtype, 
def build_kernel(matrix_size: int):
    """Build the IR module for the dense fast-inverse kernel.

    The kernel is specialized at build time for ``matrix_size`` x ``matrix_size``
    tiles. The batch size and the iteration count are runtime scalars.

    Args:
        matrix_size: Compile-time matrix dimension n used for every tile and view.

    Returns:
        The object produced by ``to_ir_module`` for ``tri_inv_trick_fp16``.
    """

    @to_ir_module(meta_data=make_meta_data(matrix_size))
    def tri_inv_trick_fp16(
        out_ptr: "out_ptr_type",
        in_ptr: "in_ptr_type",
        i_neg_ptr: "in_ptr_type",
        matrix_size_i32: "i32",
        log2_blocksize_i32: "i32",
    ) -> None:
        with pto.cube_section():
            c0 = const(0)
            c1 = const(1)
            n_c = const(matrix_size)

            # NOTE: despite its name, `matrix_size_i32` is used as the runtime
            # batch size; the matrix dimension itself is the build-time `n_c`.
            batch_size = s.index_cast(matrix_size_i32)
            log2_blocksize = s.index_cast(log2_blocksize_i32)
            block_idx = s.index_cast(pto.get_block_idx())
            num_cores = s.index_cast(pto.get_block_num())
            # The batch is addressed as one 2-D tensor of stacked n x n matrices.
            total_rows = batch_size * n_c

            # Persistent-kernel work split: base + remainder.
            # Every block gets `base` matrices; the first `rem` blocks get one
            # extra, so block b starts at b * base + min(b, rem).
            base = batch_size // num_cores
            rem = batch_size % num_cores
            lt_rem = s.lt(block_idx, rem)
            min_bid_rem = s.min_u(block_idx, rem)
            b_start = block_idx * base + min_bid_rem
            length = base + s.select(lt_rem, c1, c0)
            b_end = s.min_u(b_start + length, batch_size)

            tv_m = pto.as_tensor(
                in_tensor_type, ptr=in_ptr, shape=[total_rows, n_c], strides=[n_c, c1]
            )
            tv_out = pto.as_tensor(
                out_tensor_type, ptr=out_ptr, shape=[total_rows, n_c], strides=[n_c, c1]
            )
            tv_i_neg = pto.as_tensor(
                in_tensor_type, ptr=i_neg_ptr, shape=[n_c, n_c], strides=[n_c, c1]
            )

            sv_i_neg = pto.slice_view(
                in_subtensor, source=tv_i_neg, offsets=[c0, c0], sizes=[n_c, n_c]
            )

            # Four L1-side tiles (-I, x, y, +I) plus one left/right/accumulator
            # tile that every matmul below reuses.
            i_neg_l1 = pto.alloc_tile(l1_tile_type)
            x_l1 = pto.alloc_tile(l1_tile_type)
            y_l1 = pto.alloc_tile(l1_tile_type)
            i_l1 = pto.alloc_tile(l1_tile_type)
            a_l0 = pto.alloc_tile(l0a_tile_type)
            b_l0 = pto.alloc_tile(l0b_tile_type)
            c_l0 = pto.alloc_tile(l0c_tile_type)

            pto.load(sv_i_neg, i_neg_l1)
            # I = (-I) @ (-I) is batch-invariant, so compute it once.
            tile.mov(i_neg_l1, a_l0)
            tile.mov(i_neg_l1, b_l0)
            tile.matmul(a_l0, b_l0, c_l0)
            tile.mov(c_l0, i_l1)

            # Each block walks its own contiguous slice [b_start, b_end) of the batch.
            for b_idx in pto.range(b_start, b_end, c1):
                row_offset = b_idx * n_c
                sv_m = pto.slice_view(
                    in_subtensor,
                    source=tv_m,
                    offsets=[row_offset, c0],
                    sizes=[n_c, n_c],
                )
                sv_out = pto.slice_view(
                    out_subtensor,
                    source=tv_out,
                    offsets=[row_offset, c0],
                    sizes=[n_c, n_c],
                )

                # in_ptr carries A = M - I, where M is the dense matrix to invert.
                pto.load(sv_m, y_l1)

                tile.mov(y_l1, a_l0)
                tile.mov(y_l1, b_l0)
                tile.matmul(a_l0, b_l0, c_l0)
                tile.mov(c_l0, y_l1)  # y = A @ A

                # a_l0 still holds A here; only the right operand is swapped to -I.
                tile.mov(i_neg_l1, b_l0)
                tile.matmul(a_l0, b_l0, c_l0)  # c = -A

                # b_l0 still holds -I, so this accumulates (-I) @ (-I) onto -A.
                tile.mov(i_neg_l1, a_l0)
                tile.matmul_acc(c_l0, a_l0, b_l0, c_l0)  # c = I - A
                tile.mov(c_l0, x_l1)  # x = I - A

                # Mirrors:
                #   for i in range(log2_c - 1):
                #       X, Y = (X + X @ Y, Y @ Y)
                for iter_idx in pto.range(c0, log2_blocksize, c1):
                    tile.mov(x_l1, a_l0)
                    tile.mov(i_l1, b_l0)
                    tile.matmul(a_l0, b_l0, c_l0)

                    tile.mov(y_l1, b_l0)
                    tile.matmul_acc(c_l0, a_l0, b_l0, c_l0)  # x + x @ y

                    # The last iteration skips the x/y update: its result stays
                    # in c_l0 and is stored directly below.
                    with pto.if_context(iter_idx + c1 < log2_blocksize):
                        tile.mov(c_l0, x_l1)
                        # b_l0 still holds y from the mov above.
                        tile.mov(y_l1, a_l0)
                        tile.matmul(a_l0, b_l0, c_l0)
                        tile.mov(c_l0, y_l1)  # y = y @ y

                pto.store(c_l0, sv_out)

    return tri_inv_trick_fp16
def dense_stable_matrix(n, batch, scale=0.02):
    """Return a batch of dense matrices that are small perturbations of identity.

    Each matrix is ``I + scale * U`` with ``U`` drawn uniformly from [-1, 1)
    using NumPy's global random state.

    Args:
        n: Matrix dimension.
        batch: Number of matrices.
        scale: Magnitude of the perturbation added to the identity.

    Returns:
        A float32 torch tensor of shape ``(batch, n, n)``.
    """
    perturbation = np.random.uniform(-1.0, 1.0, size=(batch, n, n)).astype(
        np.float32
    )
    identity = np.eye(n, dtype=np.float32)
    # Broadcast the single identity across the batch axis.
    return torch.from_numpy(identity[None, :, :] + scale * perturbation)
def run_test(lib, n, batch_list):
    """Validate the kernel for every batch size and summarize the outcome.

    Args:
        lib: Loaded kernel library handed through to ``check_case``.
        n: Matrix size under test.
        batch_list: Batch sizes to check, in order.

    Returns:
        The failure messages returned by ``check_case`` (empty when all pass).
        A warning naming the first failure is emitted when any case fails.
    """
    # check_case returns None on success and a message string on failure.
    outcomes = (
        check_case(lib, n=n, batch=batch, atol=6e-3, rtol=5e-2, ftol=8e-3)
        for batch in batch_list
    )
    failures = [message for message in outcomes if message is not None]

    total = len(batch_list)
    print(
        f"summary: n={n}, pass={total - len(failures)}, fail={len(failures)}, total={total}"
    )
    if not failures:
        return failures

    warnings.warn(
        f"{len(failures)} cases failed. First: {failures[0]}",
        stacklevel=2,
    )
    return failures
interpreted as block-lower-triangular: + `[[A11, 0], [A21, A22]]`, with `A11/A22` size `n/2`. +- `inv(I + A11)` and `inv(I + A22)` are computed by the same fast recurrence used in + the `basic_dense` / `block_diag` demos. +- `A21` block is recovered by `-inv(I + A22) @ A21 @ inv(I + A11)`. + +`run_inverse.py` includes: +- correctness checks on structured random / ill-conditioned generators +- a precision report line in the note style: `c= | error = ...` diff --git a/examples/aot/fast_inverse/block_inversion/caller.cpp b/examples/aot/fast_inverse/block_inversion/caller.cpp new file mode 100644 index 00000000..5ee9b5ca --- /dev/null +++ b/examples/aot/fast_inverse/block_inversion/caller.cpp @@ -0,0 +1,24 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "inverse.cpp" +#endif + +#ifndef KERNEL_FN +#define KERNEL_FN tri_inv_block2x2_fp16 +#endif + +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, + void *stream, + uint8_t *tensor_out, + uint8_t *tensor_in, + uint8_t *identity_in, + uint32_t log2_blocksize) +{ + KERNEL_FN<<>>( + reinterpret_cast(tensor_out), + reinterpret_cast(tensor_in), + reinterpret_cast(identity_in), + static_cast(log2_blocksize)); +} diff --git a/examples/aot/fast_inverse/block_inversion/compile.sh b/examples/aot/fast_inverse/block_inversion/compile.sh new file mode 100644 index 00000000..4d60f8c9 --- /dev/null +++ b/examples/aot/fast_inverse/block_inversion/compile.sh @@ -0,0 +1,33 @@ +#!/usr/bin/env bash +set -euo pipefail + +ARTIFACT_DIR="./build_artifacts" +MATRIX_SIZE="${1:-64}" + +mkdir -p "${ARTIFACT_DIR}" +rm -f "${ARTIFACT_DIR}/inverse.pto" "${ARTIFACT_DIR}/inverse.cpp" inverse_lib.so + +python ./inverse_builder.py \ + --matrix-size "${MATRIX_SIZE}" \ + > "${ARTIFACT_DIR}/inverse.pto" + +ptoas --enable-insert-sync "${ARTIFACT_DIR}/inverse.pto" -o "${ARTIFACT_DIR}/inverse.cpp" + +PTO_LIB_PATH=/sources/pto-isa +# PTO_LIB_PATH=$ASCEND_TOOLKIT_HOME + +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 
-O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + -DKERNEL_CPP="\"${ARTIFACT_DIR}/inverse.cpp\"" \ + ./caller.cpp \ + -o ./inverse_lib.so diff --git a/examples/aot/fast_inverse/block_inversion/inverse_builder.py b/examples/aot/fast_inverse/block_inversion/inverse_builder.py new file mode 100644 index 00000000..1aa4ec4a --- /dev/null +++ b/examples/aot/fast_inverse/block_inversion/inverse_builder.py @@ -0,0 +1,235 @@ +# pyright: reportUndefinedVariable=false +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const +SUPPORTED_MATRIX_SIZES = (16, 32, 64, 128) + + +def make_meta_data(n: int): + h = n // 2 + + def meta_data(): + in_dtype = pto.float16 + out_dtype = pto.float32 + i32 = pto.int32 + + in_ptr_type = pto.PtrType(in_dtype) + out_ptr_type = pto.PtrType(out_dtype) + in_tensor_type = pto.TensorType(rank=2, dtype=in_dtype) + out_tensor_type = pto.TensorType(rank=2, dtype=out_dtype) + + in_subtensor_h = pto.SubTensorType(shape=[h, h], dtype=in_dtype) + out_subtensor_h = pto.SubTensorType(shape=[h, h], dtype=out_dtype) + + l1_tile_type = pto.TileBufType( + shape=[h, h], valid_shape=[h, h], dtype=in_dtype, memory_space="MAT" + ) + l0a_tile_type = pto.TileBufType( + shape=[h, h], valid_shape=[h, h], dtype=in_dtype, memory_space="LEFT" + ) + l0b_tile_type = pto.TileBufType( + shape=[h, h], valid_shape=[h, h], dtype=in_dtype, memory_space="RIGHT" + ) + l0c_tile_type = pto.TileBufType( + shape=[h, h], valid_shape=[h, h], dtype=out_dtype, memory_space="ACC" + ) + + return { + "in_ptr_type": in_ptr_type, + "out_ptr_type": out_ptr_type, + "i32": i32, + 
def build_kernel(matrix_size: int):
    """Build the IR module for one-level 2x2 block inversion of ``I + A``.

    ``A`` is read as ``[[A11, 0], [A21, A22]]`` with half-size blocks. The two
    diagonal blocks are inverted by the doubling recurrence and the lower-left
    block is formed from them. Each launched block handles one matrix, indexed
    by its block id.

    Args:
        matrix_size: Compile-time full matrix dimension n; must be even and
            at least 16.

    Returns:
        The object produced by ``to_ir_module`` for ``tri_inv_block2x2_fp16``.
    """
    assert matrix_size % 2 == 0 and matrix_size >= 16

    @to_ir_module(meta_data=make_meta_data(matrix_size))
    def tri_inv_block2x2_fp16(
        out_ptr: "out_ptr_type",
        in_ptr: "in_ptr_type",
        i_neg_ptr: "in_ptr_type",
        log2_blocksize_i32: "i32",
    ) -> None:
        with pto.cube_section():
            c0 = const(0)
            c1 = const(1)
            n_c = const(matrix_size)
            h_c = const(matrix_size // 2)

            # The recurrence runs on half-size blocks, hence log2(n) - 1 steps.
            log2_half = s.index_cast(log2_blocksize_i32) - c1
            block_idx = s.index_cast(pto.get_block_idx())
            num_blocks = s.index_cast(pto.get_block_num())

            # One matrix per block: the batch is viewed as num_blocks stacked
            # n x n matrices and this block owns rows [row_offset, row_offset + n).
            total_rows = num_blocks * n_c
            row_offset = block_idx * n_c
            row_offset_h = row_offset + h_c

            tv_in = pto.as_tensor(
                in_tensor_type, ptr=in_ptr, shape=[total_rows, n_c], strides=[n_c, c1]
            )
            tv_out = pto.as_tensor(
                out_tensor_type, ptr=out_ptr, shape=[total_rows, n_c], strides=[n_c, c1]
            )
            # i_neg_ptr is addressed as a half-size (h x h) matrix.
            tv_i_neg = pto.as_tensor(
                in_tensor_type, ptr=i_neg_ptr, shape=[h_c, h_c], strides=[h_c, c1]
            )
            sv_i_neg = pto.slice_view(
                in_subtensor_h, source=tv_i_neg, offsets=[c0, c0], sizes=[h_c, h_c]
            )

            # Input views: upper-left (A11), lower-left (A21), lower-right (A22).
            sv_a11 = pto.slice_view(
                in_subtensor_h, source=tv_in, offsets=[row_offset, c0], sizes=[h_c, h_c]
            )
            sv_a21 = pto.slice_view(
                in_subtensor_h,
                source=tv_in,
                offsets=[row_offset_h, c0],
                sizes=[h_c, h_c],
            )
            sv_a22 = pto.slice_view(
                in_subtensor_h,
                source=tv_in,
                offsets=[row_offset_h, h_c],
                sizes=[h_c, h_c],
            )

            # Output views for the same three blocks. No view is created for the
            # upper-right block, so this kernel never writes it.
            sv_out11 = pto.slice_view(
                out_subtensor_h,
                source=tv_out,
                offsets=[row_offset, c0],
                sizes=[h_c, h_c],
            )
            sv_out21 = pto.slice_view(
                out_subtensor_h,
                source=tv_out,
                offsets=[row_offset_h, c0],
                sizes=[h_c, h_c],
            )
            sv_out22 = pto.slice_view(
                out_subtensor_h,
                source=tv_out,
                offsets=[row_offset_h, h_c],
                sizes=[h_c, h_c],
            )

            x11_l1 = pto.alloc_tile(l1_tile_type)
            y11_l1 = pto.alloc_tile(l1_tile_type)
            x22_l1 = pto.alloc_tile(l1_tile_type)
            y22_l1 = pto.alloc_tile(l1_tile_type)
            a21_l1 = pto.alloc_tile(l1_tile_type)
            neg_i_l1 = pto.alloc_tile(l1_tile_type)
            pos_i_l1 = pto.alloc_tile(l1_tile_type)
            tmp_l1 = pto.alloc_tile(l1_tile_type)

            # A single left/right/accumulator tile is reused by every matmul.
            a_l0 = pto.alloc_tile(l0a_tile_type)
            b_l0 = pto.alloc_tile(l0b_tile_type)
            c_l0 = pto.alloc_tile(l0c_tile_type)

            # Build +/- identity tiles for half-size blocks.
            # Also seed x11 = x22 = I for the recurrence below.
            pto.load(sv_i_neg, neg_i_l1)
            tile.mov(neg_i_l1, a_l0)
            tile.mov(neg_i_l1, b_l0)
            tile.matmul(a_l0, b_l0, c_l0)  # c = (-I) @ (-I) = I
            tile.mov(c_l0, pos_i_l1)
            tile.mov(c_l0, x11_l1)  # x11 = I
            tile.mov(c_l0, x22_l1)  # x22 = I

            # Invert (I + A11): start the recurrence with y11 = -A11, x11 = I.
            # The loop then computes x_{k+1} = x_k(I + y_k), y_{k+1} = y_k^2
            # which gives (I + A11)^{-1} after log2_half steps.
            pto.load(sv_a11, y11_l1)
            tile.mov(y11_l1, a_l0)
            tile.mov(neg_i_l1, b_l0)
            tile.matmul(a_l0, b_l0, c_l0)  # c = -A11
            tile.mov(c_l0, y11_l1)  # y11 = -A11

            for iter_idx in pto.range(c0, log2_half, c1):
                tile.mov(x11_l1, a_l0)
                tile.mov(pos_i_l1, b_l0)
                tile.matmul(a_l0, b_l0, c_l0)  # c = x11

                tile.mov(y11_l1, b_l0)
                tile.matmul_acc(c_l0, a_l0, b_l0, c_l0)  # c = x11 + x11 @ y11

                # The last iteration skips the update; its result stays in c_l0.
                with pto.if_context(iter_idx + c1 < log2_half):
                    tile.mov(c_l0, x11_l1)
                    # b_l0 still holds y11 from the mov above.
                    tile.mov(y11_l1, a_l0)
                    tile.matmul(a_l0, b_l0, c_l0)
                    tile.mov(c_l0, y11_l1)  # y11 = y11 @ y11

            # Keep the final inverse in x11 as well: the A21 term below needs it.
            tile.mov(c_l0, x11_l1)
            pto.store(c_l0, sv_out11)

            # Invert (I + A22): start with y22 = -A22, x22 = I (already set above).
            pto.load(sv_a22, y22_l1)
            tile.mov(y22_l1, a_l0)
            tile.mov(neg_i_l1, b_l0)
            tile.matmul(a_l0, b_l0, c_l0)  # c = -A22
            tile.mov(c_l0, y22_l1)  # y22 = -A22

            for iter_idx in pto.range(c0, log2_half, c1):
                tile.mov(x22_l1, a_l0)
                tile.mov(pos_i_l1, b_l0)
                tile.matmul(a_l0, b_l0, c_l0)  # c = x22

                tile.mov(y22_l1, b_l0)
                tile.matmul_acc(c_l0, a_l0, b_l0, c_l0)  # c = x22 + x22 @ y22

                with pto.if_context(iter_idx + c1 < log2_half):
                    tile.mov(c_l0, x22_l1)
                    # b_l0 still holds y22 from the mov above.
                    tile.mov(y22_l1, a_l0)
                    tile.matmul(a_l0, b_l0, c_l0)
                    tile.mov(c_l0, y22_l1)  # y22 = y22 @ y22

            tile.mov(c_l0, x22_l1)
            pto.store(c_l0, sv_out22)

            # A21 term in block inversion:
            #   X21 = - X22 @ A21 @ X11
            pto.load(sv_a21, a21_l1)

            tile.mov(x22_l1, a_l0)
            tile.mov(a21_l1, b_l0)
            tile.matmul(a_l0, b_l0, c_l0)
            tile.mov(c_l0, tmp_l1)  # tmp = X22 @ A21

            tile.mov(tmp_l1, a_l0)
            tile.mov(x11_l1, b_l0)
            tile.matmul(a_l0, b_l0, c_l0)
            tile.mov(c_l0, tmp_l1)  # tmp = X22 @ A21 @ X11

            # Multiplying by -I on the left applies the leading minus sign.
            tile.mov(neg_i_l1, a_l0)
            tile.mov(tmp_l1, b_l0)
            tile.matmul(a_l0, b_l0, c_l0)
            pto.store(c_l0, sv_out21)

    return tri_inv_block2x2_fp16
def ill_matrix(n, batch, offdiag=0.5):
    """Return a batch of identical strictly-lower-triangular matrices.

    Every entry below the main diagonal equals ``offdiag``; the diagonal and
    everything above it are zero.

    Args:
        n: Matrix dimension.
        batch: Number of copies stacked along the first axis.
        offdiag: Value placed in every strictly-lower-triangular entry.

    Returns:
        A float32 torch tensor of shape ``(batch, n, n)``.
    """
    strictly_lower = np.tril(np.ones((n, n), dtype=np.float32), k=-1)
    template = (offdiag * strictly_lower).astype(np.float32, copy=False)
    # All batch entries are the same matrix: broadcast once, then copy so the
    # result is an ordinary writable array.
    stacked = np.broadcast_to(template, (batch, n, n)).copy()
    return torch.from_numpy(stacked)
def reference_inverse(inp_delta):
    """Compute the float64 reference result ``inv(I + inp_delta)`` on the CPU.

    Args:
        inp_delta: Torch tensor whose trailing two dimensions are square
            matrices; the identity is added to each before inverting.

    Returns:
        A float64 torch tensor with the same shape as ``inp_delta``.
    """
    delta64 = inp_delta.cpu().numpy().astype(np.float64)
    # The identity broadcasts over any leading batch dimensions.
    shifted = delta64 + np.eye(inp_delta.shape[-1], dtype=np.float64)
    return torch.from_numpy(np.linalg.inv(shifted))
def build(M=16, K=64, N=32, lhs_variant="e5m2", rhs_variant="e5m2"):
    """Build the IR module for a single-tile MXFP8 matmul with bias.

    Args:
        M: Row count of the left operand and of the accumulator tile.
        K: Shared inner dimension of the two operands.
        N: Column count of the right operand, bias and accumulator tile.
        lhs_variant: MXFP8 variant name passed to ``pto.make_mxfp8`` for the
            left operand.
        rhs_variant: MXFP8 variant name passed to ``pto.make_mxfp8`` for the
            right operand.

    Returns:
        The object produced by ``to_ir_module`` for ``matmul_mxfp8``.
    """

    def meta_data():
        mx = pto.make_mxfp8(lhs=lhs_variant, rhs=rhs_variant)
        # Length of the scale tensors along the K axis, as reported by `mx`.
        scale_k = mx.scale_k(K)

        ptr_lhs = pto.PtrType(mx.lhs)
        ptr_rhs = pto.PtrType(mx.rhs)
        ptr_scale = pto.PtrType(mx.scale)
        ptr_bias = pto.PtrType(mx.acc)

        lhs_tensor = pto.TensorType(rank=2, dtype=mx.lhs)
        rhs_tensor = pto.TensorType(rank=2, dtype=mx.rhs)
        lhs_scale_tensor = pto.TensorType(rank=2, dtype=mx.scale)
        rhs_scale_tensor = pto.TensorType(rank=2, dtype=mx.scale)
        bias_tensor = pto.TensorType(rank=2, dtype=mx.acc)

        lhs_tile_view = pto.SubTensorType(shape=[M, K], dtype=mx.lhs)
        rhs_tile_view = pto.SubTensorType(shape=[K, N], dtype=mx.rhs)
        lhs_scale_tile_view = pto.SubTensorType(shape=[M, scale_k], dtype=mx.scale)
        rhs_scale_tile_view = pto.SubTensorType(shape=[scale_k, N], dtype=mx.scale)
        bias_tile_view = pto.SubTensorType(shape=[1, N], dtype=mx.acc)

        lhs_tile = pto.TileBufType(shape=[M, K], dtype=mx.lhs, memory_space="LEFT")
        rhs_tile = pto.TileBufType(shape=[K, N], dtype=mx.rhs, memory_space="RIGHT")
        lhs_scale_tile = pto.LeftScaleTileBufType(shape=[M, scale_k], dtype=mx.scale)
        rhs_scale_tile = pto.RightScaleTileBufType(shape=[scale_k, N], dtype=mx.scale)
        bias_tile = pto.TileBufType(shape=[1, N], dtype=mx.acc, memory_space="BIAS")
        acc_tile = pto.TileBufType(shape=[M, N], dtype=mx.acc, memory_space="ACC")

        # Every local above becomes a metadata entry keyed by its variable
        # name, so renaming or adding a local changes the returned dict.
        return locals()

    const = pto.const

    @to_ir_module(meta_data=meta_data)
    def matmul_mxfp8(
        a_ptr: "ptr_lhs",
        a_scale_ptr: "ptr_scale",
        b_ptr: "ptr_rhs",
        b_scale_ptr: "ptr_scale",
        bias_ptr: "ptr_bias",
    ) -> None:
        c0 = const(0)
        c1 = const(1)
        cM = const(M)
        cK = const(K)
        cN = const(N)
        # `scale_k` and the type names used below are locals of meta_data();
        # presumably to_ir_module makes them visible here -- TODO confirm.
        cScaleK = const(scale_k)

        # Row-major 2-D views over the raw pointers.
        tv_a = pto.as_tensor(lhs_tensor, ptr=a_ptr, shape=[cM, cK], strides=[cK, c1])
        tv_b = pto.as_tensor(rhs_tensor, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1])
        tv_scale_a = pto.as_tensor(
            lhs_scale_tensor,
            ptr=a_scale_ptr,
            shape=[cM, cScaleK],
            strides=[cScaleK, c1],
        )
        tv_scale_b = pto.as_tensor(
            rhs_scale_tensor, ptr=b_scale_ptr, shape=[cScaleK, cN], strides=[cN, c1]
        )
        tv_bias = pto.as_tensor(
            bias_tensor, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1]
        )

        # Each slice covers its whole tensor: the kernel handles one tile.
        sv_a = pto.slice_view(
            lhs_tile_view, source=tv_a, offsets=[c0, c0], sizes=[cM, cK]
        )
        sv_b = pto.slice_view(
            rhs_tile_view, source=tv_b, offsets=[c0, c0], sizes=[cK, cN]
        )
        sv_scale_a = pto.slice_view(
            lhs_scale_tile_view,
            source=tv_scale_a,
            offsets=[c0, c0],
            sizes=[cM, cScaleK],
        )
        sv_scale_b = pto.slice_view(
            rhs_scale_tile_view,
            source=tv_scale_b,
            offsets=[c0, c0],
            sizes=[cScaleK, cN],
        )
        sv_bias = pto.slice_view(
            bias_tile_view, source=tv_bias, offsets=[c0, c0], sizes=[c1, cN]
        )

        with pto.cube_section():
            a_tile = pto.alloc_tile(lhs_tile)
            b_tile = pto.alloc_tile(rhs_tile)
            a_scale_tile = pto.alloc_tile(lhs_scale_tile)
            b_scale_tile = pto.alloc_tile(rhs_scale_tile)
            bias_tile_buf = pto.alloc_tile(bias_tile)
            acc_tile_buf = pto.alloc_tile(acc_tile)

            pto.load(sv_a, a_tile)
            pto.load(sv_b, b_tile)
            pto.load(sv_scale_a, a_scale_tile)
            pto.load(sv_scale_b, b_scale_tile)
            pto.load(sv_bias, bias_tile_buf)
            # NOTE(review): the accumulator tile is not stored anywhere after
            # this call and the kernel has no output pointer -- confirm that
            # this is intentional for an IR-generation example.
            pto.matmul_mx_bias(
                a_tile, a_scale_tile, b_tile, b_scale_tile, bias_tile_buf, acc_tile_buf
            )

    return matmul_mxfp8
# Minimal MXFP8 matmul kernel: load one data tile and one scale tile per
# operand, then issue a single scaled matmul into the accumulator tile.
# The quoted annotations and the type names used in the body refer to
# entries returned by meta_data().
@to_ir_module(meta_data=meta_data)
def matmul_mxfp8_core(
    a: "a_ptr",
    scale_a: "scale_ptr",
    b: "b_ptr",
    scale_b: "scale_ptr",
) -> None:
    c0 = pto.const(0)
    c1 = pto.const(1)
    cM = pto.const(M)
    cK = pto.const(K)
    cN = pto.const(N)
    # `scale_k` is a local of meta_data(); presumably to_ir_module makes it
    # visible here -- TODO confirm.
    cScaleK = pto.const(scale_k)

    # Row-major 2-D views over the raw pointers.
    tv_a = pto.as_tensor(a_tensor, ptr=a, shape=[cM, cK], strides=[cK, c1])
    tv_b = pto.as_tensor(b_tensor, ptr=b, shape=[cK, cN], strides=[cN, c1])
    tv_scale_a = pto.as_tensor(
        scale_a_tensor, ptr=scale_a, shape=[cM, cScaleK], strides=[cScaleK, c1]
    )
    tv_scale_b = pto.as_tensor(
        scale_b_tensor, ptr=scale_b, shape=[cScaleK, cN], strides=[cN, c1]
    )

    # Each slice covers its whole tensor: the kernel handles one tile.
    sv_a = pto.slice_view(a_view, source=tv_a, offsets=[c0, c0], sizes=[cM, cK])
    sv_b = pto.slice_view(b_view, source=tv_b, offsets=[c0, c0], sizes=[cK, cN])
    sv_scale_a = pto.slice_view(
        scale_a_view, source=tv_scale_a, offsets=[c0, c0], sizes=[cM, cScaleK]
    )
    sv_scale_b = pto.slice_view(
        scale_b_view, source=tv_scale_b, offsets=[c0, c0], sizes=[cScaleK, cN]
    )

    with pto.cube_section():
        ta = pto.alloc_tile(a_tile)
        tb = pto.alloc_tile(b_tile)
        tsa = pto.alloc_tile(scale_a_tile)
        tsb = pto.alloc_tile(scale_b_tile)
        tc = pto.alloc_tile(acc_tile)

        pto.load(sv_a, ta)
        pto.load(sv_b, tb)
        pto.load(sv_scale_a, tsa)
        pto.load(sv_scale_b, tsb)

        # Core call: MXFP8 data tile + scale tile -> Acc tile
        pto.matmul_mx(ta, tsa, tb, tsb, tc)
= [ + (4096, 4096), + (8192, 8192), + (16384, 16384), +] +N_WARMUP = 5 +N_REPEAT = 20 +PLOT_SHAPES_NK = [(8192, 8192), (16384, 16384)] +DEFAULT_PLOT_DIR = Path("fig") + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path): + lib = ctypes.CDLL(os.path.abspath(lib_path)) + lib.call_kernel.argtypes = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ] + lib.call_kernel.restype = None + + def matmul_abt(a, b, *, block_dim=BLOCK_DIM, stream_ptr=None): + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ + + m = int(a.shape[0]) + k = int(a.shape[1]) + n = int(b.shape[0]) + c = torch.empty((m, n), device=a.device, dtype=a.dtype) + lib.call_kernel( + block_dim, + stream_ptr, + torch_to_ctypes(a), + torch_to_ctypes(b), + torch_to_ctypes(c), + m, + n, + k, + ) + return c + + return matmul_abt + + +def _parse_int_list(raw): + parts = [p.strip() for p in raw.split(",") if p.strip()] + if not parts: + raise ValueError("List cannot be empty.") + return [int(p) for p in parts] + + +def _time_us(fn, a_list, b_list, warmup, repeat): + for a, b in zip(a_list[:warmup], b_list[:warmup]): + fn(a, b) + torch.npu.synchronize() + + start = torch.npu.Event(enable_timing=True) + end = torch.npu.Event(enable_timing=True) + start.record() + for a, b in zip(a_list[warmup : warmup + repeat], b_list[warmup : warmup + repeat]): + fn(a, b) + end.record() + torch.npu.synchronize() + return start.elapsed_time(end) * 1000.0 / repeat + + +def _maybe_plot(rows, plot_dir): + try: + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed; skipping plot generation.") + return + + style_candidates = ("seaborn-v0_8-whitegrid", "seaborn-whitegrid") + for style_name in style_candidates: + try: + plt.style.use(style_name) + break + except OSError: + continue + + plt.rcParams["figure.facecolor"] = "white" 
+ plt.rcParams["axes.facecolor"] = "white" + plot_dir.mkdir(parents=True, exist_ok=True) + title_scale = 1.5 + axis_label_scale = 1.5 + legend_scale = 2.0 + + step_defs = [ + ( + "step1", + "single_auto_noswizzle_tflops", + "Step1 Kernel", + "flops_step1_baseline.png", + ), + ( + "step2", + "double_auto_noswizzle_tflops", + "Step2 Kernel", + "flops_step2_doublebuf.png", + ), + ( + "step3", + "double_auto_swizzle_tflops", + "Step3 Kernel", + "flops_step3_swizzle.png", + ), + ( + "step4", + "double_manual_swizzle_tflops", + "Step4 Kernel", + "flops_step4_manual_pipeline.png", + ), + ] + + for _, custom_key, custom_label, out_name in step_defs: + fig, axes = plt.subplots(1, 2, figsize=(12, 5), sharey=True) + for ax, (n, k) in zip(axes, PLOT_SHAPES_NK): + base_title_size = ax.title.get_size() + base_label_size = ax.xaxis.label.get_size() + chunk = [r for r in rows if r["n"] == n and r["k"] == k] + if not chunk: + ax.set_title( + f"TFLOPS vs M (N={n}, K={k})", + fontsize=base_title_size * title_scale, + ) + ax.text( + 0.5, + 0.5, + "No data", + transform=ax.transAxes, + ha="center", + va="center", + ) + ax.set_xlabel("M", fontsize=base_label_size * axis_label_scale) + ax.set_ylabel("TFLOPS", fontsize=base_label_size * axis_label_scale) + ax.grid(alpha=0.25) + continue + + chunk = sorted(chunk, key=lambda r: r["m"]) + m_values = [r["m"] for r in chunk] + matmul_tflops = [r["torch_matmul_tflops"] for r in chunk] + custom_tflops = [r[custom_key] for r in chunk] + + ax.plot( + m_values, + matmul_tflops, + marker="x", + linestyle="--", + color="#111111", + label="torch.matmul", + ) + ax.plot( + m_values, + custom_tflops, + marker="o", + linestyle="-", + color="#1f77b4", + label=custom_label, + ) + ax.set_title( + f"TFLOPS vs M (N={n}, K={k})", fontsize=base_title_size * title_scale + ) + ax.set_xlabel("M", fontsize=base_label_size * axis_label_scale) + ax.set_ylabel("TFLOPS", fontsize=base_label_size * axis_label_scale) + ax.set_xlim(left=0) + ax.set_ylim(bottom=0) + 
ax.grid(alpha=0.25) + ax.legend(fontsize=8 * legend_scale) + + plt.tight_layout() + out = plot_dir / out_name + plt.savefig(out, dpi=160, format="png") + plt.close(fig) + print(f"Saved plot: {out}") + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Stepwise performance benchmark for buffering, swizzle, and manual sync." + ) + parser.add_argument( + "--double-auto-swizzle-lib", + type=str, + default="./build_artifacts/step3_swizzle_kernel.so", + help="Path to double-buffer auto-sync swizzled shared library.", + ) + parser.add_argument( + "--double-auto-noswizzle-lib", + type=str, + default="./build_artifacts/step2_doublebuffer_kernel.so", + help="Path to double-buffer auto-sync non-swizzle shared library.", + ) + parser.add_argument( + "--double-manual-swizzle-lib", + type=str, + default="./build_artifacts/step4_manual_pipelining_kernel.so", + help="Path to double-buffer manual-sync swizzled shared library.", + ) + parser.add_argument( + "--single-auto-noswizzle-lib", + type=str, + default="./build_artifacts/step1_baseline_kernel.so", + help="Path to single-buffer auto-sync non-swizzle shared library.", + ) + parser.add_argument( + "--m-list", + type=str, + default=",".join(str(m) for m in M_LIST), + help="Comma-separated M values (default: script M_LIST).", + ) + parser.add_argument( + "--warmup", + type=int, + default=N_WARMUP, + help=f"Warmup iterations (default: {N_WARMUP}).", + ) + parser.add_argument( + "--repeat", + type=int, + default=N_REPEAT, + help=f"Timed iterations (default: {N_REPEAT}).", + ) + parser.add_argument( + "--plot-dir", + type=str, + default=str(DEFAULT_PLOT_DIR), + help=f"Plot output directory (default: {DEFAULT_PLOT_DIR}).", + ) + return parser.parse_args() + + +def main(): + args = _parse_args() + if args.warmup < 1 or args.repeat < 1: + raise ValueError("--warmup and --repeat must be positive integers.") + + base_dir = Path(__file__).resolve().parent + + double_auto_swizzle_lib = 
Path(args.double_auto_swizzle_lib) + if not double_auto_swizzle_lib.is_absolute(): + double_auto_swizzle_lib = base_dir / double_auto_swizzle_lib + double_auto_noswizzle_lib = Path(args.double_auto_noswizzle_lib) + if not double_auto_noswizzle_lib.is_absolute(): + double_auto_noswizzle_lib = base_dir / double_auto_noswizzle_lib + double_manual_swizzle_lib = Path(args.double_manual_swizzle_lib) + if not double_manual_swizzle_lib.is_absolute(): + double_manual_swizzle_lib = base_dir / double_manual_swizzle_lib + single_auto_noswizzle_lib = Path(args.single_auto_noswizzle_lib) + if not single_auto_noswizzle_lib.is_absolute(): + single_auto_noswizzle_lib = base_dir / single_auto_noswizzle_lib + if not double_auto_swizzle_lib.exists(): + raise FileNotFoundError( + f"Double-buffer auto-sync swizzle library not found: {double_auto_swizzle_lib}" + ) + if not double_auto_noswizzle_lib.exists(): + raise FileNotFoundError( + f"Double-buffer auto-sync non-swizzle library not found: {double_auto_noswizzle_lib}" + ) + if not double_manual_swizzle_lib.exists(): + raise FileNotFoundError( + f"Double-buffer manual-sync swizzle library not found: {double_manual_swizzle_lib}" + ) + if not single_auto_noswizzle_lib.exists(): + raise FileNotFoundError( + f"Single-buffer auto-sync non-swizzle library not found: {single_auto_noswizzle_lib}" + ) + plot_dir = Path(args.plot_dir) + if not plot_dir.is_absolute(): + plot_dir = base_dir / plot_dir + + device = get_test_device() + torch.npu.set_device(device) + torch.manual_seed(0) + + double_auto_swizzle_mm = load_lib(str(double_auto_swizzle_lib)) + double_auto_noswizzle_mm = load_lib(str(double_auto_noswizzle_lib)) + double_manual_swizzle_mm = load_lib(str(double_manual_swizzle_lib)) + single_auto_noswizzle_mm = load_lib(str(single_auto_noswizzle_lib)) + m_list = _parse_int_list(args.m_list) + + ratios_step1_double_vs_single_noswizzle = [] + ratios_step2_swizzle_vs_noswizzle = [] + ratios_step3_manual_vs_auto_swizzle = [] + plot_rows = [] + 
print(f"double-buffer auto-sync swizzle lib: {double_auto_swizzle_lib}") + print(f"double-buffer auto-sync non-swizzle lib: {double_auto_noswizzle_lib}") + print(f"double-buffer manual-sync swizzle lib: {double_manual_swizzle_lib}") + print(f"single-buffer auto-sync non-swizzle lib: {single_auto_noswizzle_lib}") + print("") + + for n, k in SHAPES_NK: + print(f"=== N={n}, K={k} ===") + for m in m_list: + alloc = args.warmup + args.repeat + a_list = [ + torch.randn(m, k, dtype=torch.float16, device=device) + for _ in range(alloc) + ] + b_list = [ + torch.randn(n, k, dtype=torch.float16, device=device) + for _ in range(alloc) + ] + + double_auto_swizzle_us = _time_us( + double_auto_swizzle_mm, a_list, b_list, args.warmup, args.repeat + ) + double_auto_noswizzle_us = _time_us( + double_auto_noswizzle_mm, a_list, b_list, args.warmup, args.repeat + ) + double_manual_swizzle_us = _time_us( + double_manual_swizzle_mm, a_list, b_list, args.warmup, args.repeat + ) + single_auto_noswizzle_us = _time_us( + single_auto_noswizzle_mm, a_list, b_list, args.warmup, args.repeat + ) + torch_matmul_us = _time_us( + lambda a, b: torch.matmul(a, b.transpose(0, 1)), + a_list, + b_list, + args.warmup, + args.repeat, + ) + del a_list, b_list + torch.npu.empty_cache() + + flops = 2.0 * m * n * k + double_auto_swizzle_tflops = flops / double_auto_swizzle_us / 1e6 + double_auto_noswizzle_tflops = flops / double_auto_noswizzle_us / 1e6 + double_manual_swizzle_tflops = flops / double_manual_swizzle_us / 1e6 + single_auto_noswizzle_tflops = flops / single_auto_noswizzle_us / 1e6 + torch_matmul_tflops = flops / torch_matmul_us / 1e6 + + # Step 1: buffering effect (double-buffer vs single-buffer, both non-swizzle auto-sync). + step1_double_vs_single = ( + double_auto_noswizzle_tflops / single_auto_noswizzle_tflops + ) + # Step 2: swizzle effect (double-buffer auto-sync swizzle vs non-swizzle). 
+ step2_swizzle_vs_noswizzle = ( + double_auto_swizzle_tflops / double_auto_noswizzle_tflops + ) + # Step 3: manual-sync effect (double-buffer swizzle manual-sync vs auto-sync). + step3_manual_vs_auto = ( + double_manual_swizzle_tflops / double_auto_swizzle_tflops + ) + + ratios_step1_double_vs_single_noswizzle.append(step1_double_vs_single) + ratios_step2_swizzle_vs_noswizzle.append(step2_swizzle_vs_noswizzle) + ratios_step3_manual_vs_auto_swizzle.append(step3_manual_vs_auto) + plot_rows.append( + { + "m": m, + "n": n, + "k": k, + "torch_matmul_tflops": torch_matmul_tflops, + "single_auto_noswizzle_tflops": single_auto_noswizzle_tflops, + "double_auto_noswizzle_tflops": double_auto_noswizzle_tflops, + "double_auto_swizzle_tflops": double_auto_swizzle_tflops, + "double_manual_swizzle_tflops": double_manual_swizzle_tflops, + } + ) + + print( + f"(M,N,K)=({m},{n},{k}) " + f"single_noswizzle={single_auto_noswizzle_tflops:.3f}TF, " + f"double_noswizzle_auto={double_auto_noswizzle_tflops:.3f}TF, " + f"double_swizzle_auto={double_auto_swizzle_tflops:.3f}TF, " + f"double_swizzle_manual={double_manual_swizzle_tflops:.3f}TF, " + f"torch_matmul={torch_matmul_tflops:.3f}TF, " + f"step1_ratio(double_noswizzle_auto/single_noswizzle)={step1_double_vs_single:.3f}x, " + f"step2_ratio(double_swizzle_auto/double_noswizzle_auto)={step2_swizzle_vs_noswizzle:.3f}x, " + f"step3_ratio(double_swizzle_manual/double_swizzle_auto)={step3_manual_vs_auto:.3f}x" + ) + print("") + + avg_step1 = sum(ratios_step1_double_vs_single_noswizzle) / len( + ratios_step1_double_vs_single_noswizzle + ) + min_step1 = min(ratios_step1_double_vs_single_noswizzle) + max_step1 = max(ratios_step1_double_vs_single_noswizzle) + avg_step2 = sum(ratios_step2_swizzle_vs_noswizzle) / len( + ratios_step2_swizzle_vs_noswizzle + ) + min_step2 = min(ratios_step2_swizzle_vs_noswizzle) + max_step2 = max(ratios_step2_swizzle_vs_noswizzle) + avg_step3 = sum(ratios_step3_manual_vs_auto_swizzle) / len( + 
ratios_step3_manual_vs_auto_swizzle + ) + min_step3 = min(ratios_step3_manual_vs_auto_swizzle) + max_step3 = max(ratios_step3_manual_vs_auto_swizzle) + + print("=== Summary ===") + print("Step1 (double-buffer speedup, both non-swizzle auto-sync):") + print(f"avg FLOP ratio(double_noswizzle_auto/single_noswizzle): {avg_step1:.3f}x") + print(f"min FLOP ratio(double_noswizzle_auto/single_noswizzle): {min_step1:.3f}x") + print(f"max FLOP ratio(double_noswizzle_auto/single_noswizzle): {max_step1:.3f}x") + print("Step2 (swizzle speedup, both double-buffer auto-sync):") + print( + f"avg FLOP ratio(double_swizzle_auto/double_noswizzle_auto): {avg_step2:.3f}x" + ) + print( + f"min FLOP ratio(double_swizzle_auto/double_noswizzle_auto): {min_step2:.3f}x" + ) + print( + f"max FLOP ratio(double_swizzle_auto/double_noswizzle_auto): {max_step2:.3f}x" + ) + print("Step3 (manual-sync speedup, both double-buffer swizzle):") + print( + f"avg FLOP ratio(double_swizzle_manual/double_swizzle_auto): {avg_step3:.3f}x" + ) + print( + f"min FLOP ratio(double_swizzle_manual/double_swizzle_auto): {min_step3:.3f}x" + ) + print( + f"max FLOP ratio(double_swizzle_manual/double_swizzle_auto): {max_step3:.3f}x" + ) + + _maybe_plot(plot_rows, plot_dir) + + +if __name__ == "__main__": + main() diff --git a/examples/aot/matmul_optimization_guide/caller.cpp b/examples/aot/matmul_optimization_guide/caller.cpp new file mode 100644 index 00000000..ac10bd3a --- /dev/null +++ b/examples/aot/matmul_optimization_guide/caller.cpp @@ -0,0 +1,28 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "matmul.cpp" +#endif + +#ifndef KERNEL_FN +#define KERNEL_FN matmul_kernel_ABt +#endif + +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, + void *stream, + uint8_t *x, + uint8_t *y, + uint8_t *z, + int M, + int N, + int K) +{ + KERNEL_FN<<>>( + reinterpret_cast(x), + reinterpret_cast(y), + reinterpret_cast(z), + static_cast(M), + static_cast(N), + static_cast(K)); +} diff --git 
a/examples/aot/matmul_optimization_guide/common_utils.py b/examples/aot/matmul_optimization_guide/common_utils.py new file mode 100644 index 00000000..58d8b801 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/common_utils.py @@ -0,0 +1,76 @@ +from ptodsl import pto +from ptodsl import scalar as s + +const = s.const + +M_TILE = 128 +K_QTILE = 64 +K_TILE = 256 +K_DTILE = 512 +N_FULL = 256 +SWIZZLE_COUNT = 5 + + +def build_meta_data(): + def meta_data(): + dtype = pto.float16 + acc_dtype = pto.float32 + ptr_type = pto.PtrType(dtype) + i32 = pto.int32 + tv_2d = pto.TensorType(rank=2, dtype=dtype) + + tile_view_a = pto.SubTensorType(shape=[M_TILE, K_DTILE], dtype=dtype) + tile_view_b = pto.SubTensorType(shape=[K_TILE, N_FULL], dtype=dtype) + tile_view_c = pto.SubTensorType(shape=[M_TILE, N_FULL], dtype=dtype) + + b_l1_cfg = pto.TileBufConfig( + blayout="RowMajor", slayout="ColMajor", s_fractal_size=512 + ) + + tile_buf_a_l1 = pto.TileBufType( + shape=[M_TILE, K_DTILE], dtype=dtype, memory_space="MAT" + ) + tile_buf_b_l1 = pto.TileBufType( + shape=[K_TILE, N_FULL], dtype=dtype, memory_space="MAT", config=b_l1_cfg + ) + tile_buf_a_l0 = pto.TileBufType( + shape=[M_TILE, K_QTILE], dtype=dtype, memory_space="LEFT" + ) + tile_buf_b_l0 = pto.TileBufType( + shape=[K_QTILE, N_FULL], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_c = pto.TileBufType( + shape=[M_TILE, N_FULL], dtype=acc_dtype, memory_space="ACC" + ) + + return { + "ptr_type": ptr_type, + "i32": i32, + "tv_2d": tv_2d, + "tile_view_a": tile_view_a, + "tile_view_b": tile_view_b, + "tile_view_c": tile_view_c, + "tile_buf_a_l1": tile_buf_a_l1, + "tile_buf_b_l1": tile_buf_b_l1, + "tile_buf_a_l0": tile_buf_a_l0, + "tile_buf_b_l0": tile_buf_b_l0, + "tile_buf_c": tile_buf_c, + } + + return meta_data + + +def swizzle_nz(li, m_loop, n_loop, c_swizzle, c_swizzle_m1, c1, c2): + tile_block_loop = (n_loop + c_swizzle_m1) // c_swizzle + tile_block_span = c_swizzle * m_loop + tile_block_idx = li // tile_block_span + 
in_tile_block_idx = li % tile_block_span + is_last_block = tile_block_idx == (tile_block_loop - c1) + n_col_tail = n_loop - c_swizzle * tile_block_idx + n_col = s.select(is_last_block, n_col_tail, c_swizzle) + m_idx = in_tile_block_idx // n_col + n_idx = tile_block_idx * c_swizzle + (in_tile_block_idx % n_col) + odd_block = (tile_block_idx % c2) == c1 + flipped_m_idx = m_loop - m_idx - c1 + m_idx = s.select(odd_block, flipped_m_idx, m_idx) + return m_idx, n_idx diff --git a/examples/aot/matmul_optimization_guide/compile.sh b/examples/aot/matmul_optimization_guide/compile.sh new file mode 100644 index 00000000..d1a60813 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/compile.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +set -euo pipefail + +ARTIFACT_DIR="./build_artifacts" +mkdir -p "${ARTIFACT_DIR}" + +rm -f "${ARTIFACT_DIR}"/*.pto "${ARTIFACT_DIR}"/*.cpp "${ARTIFACT_DIR}"/*.so + +# Step1 baseline: functionally correct dynamic-shape matmul without optimizations. +python ./step1_baseline.py > "${ARTIFACT_DIR}/step1_baseline.pto" +ptoas --enable-insert-sync "${ARTIFACT_DIR}/step1_baseline.pto" -o "${ARTIFACT_DIR}/step1_baseline.cpp" + +bisheng -fPIC -shared -xcce -O2 -std=c++17 \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -I"${ASCEND_TOOLKIT_HOME}/include" \ + -DKERNEL_CPP="\"${ARTIFACT_DIR}/step1_baseline.cpp\"" \ + -DKERNEL_FN=matmul_kernel_step1_baseline \ + ./caller.cpp \ + -o "${ARTIFACT_DIR}/step1_baseline_kernel.so" + +# Step2: double-buffer only (no swizzle, auto-sync). 
+python ./step2_doublebuffer.py > "${ARTIFACT_DIR}/step2_doublebuffer.pto" +ptoas --enable-insert-sync "${ARTIFACT_DIR}/step2_doublebuffer.pto" -o "${ARTIFACT_DIR}/step2_doublebuffer.cpp" + +bisheng -fPIC -shared -xcce -O2 -std=c++17 \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -I"${ASCEND_TOOLKIT_HOME}/include" \ + -DKERNEL_CPP="\"${ARTIFACT_DIR}/step2_doublebuffer.cpp\"" \ + -DKERNEL_FN=matmul_kernel_ABt_autosync \ + ./caller.cpp \ + -o "${ARTIFACT_DIR}/step2_doublebuffer_kernel.so" + +# Step3: swizzle + double-buffer (auto-sync). +python ./step3_swizzle.py > "${ARTIFACT_DIR}/step3_swizzle.pto" +ptoas --enable-insert-sync "${ARTIFACT_DIR}/step3_swizzle.pto" -o "${ARTIFACT_DIR}/step3_swizzle.cpp" + +bisheng -fPIC -shared -xcce -O2 -std=c++17 \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -I"${ASCEND_TOOLKIT_HOME}/include" \ + -DKERNEL_CPP="\"${ARTIFACT_DIR}/step3_swizzle.cpp\"" \ + -DKERNEL_FN=matmul_kernel_ABt_autosync \ + ./caller.cpp \ + -o "${ARTIFACT_DIR}/step3_swizzle_kernel.so" + +# Step4: swizzle + double-buffer + manual software pipelining. 
+python ./step4_manual_pipelining.py > "${ARTIFACT_DIR}/step4_manual_pipelining.pto" +ptoas "${ARTIFACT_DIR}/step4_manual_pipelining.pto" -o "${ARTIFACT_DIR}/step4_manual_pipelining.cpp" + +bisheng -fPIC -shared -xcce -O2 -std=c++17 \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -I"${ASCEND_TOOLKIT_HOME}/include" \ + -DKERNEL_CPP="\"${ARTIFACT_DIR}/step4_manual_pipelining.cpp\"" \ + -DKERNEL_FN=matmul_kernel_ABt \ + ./caller.cpp \ + -o "${ARTIFACT_DIR}/step4_manual_pipelining_kernel.so" diff --git a/examples/aot/matmul_optimization_guide/experimental/.gitignore b/examples/aot/matmul_optimization_guide/experimental/.gitignore new file mode 100644 index 00000000..7b55c120 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/experimental/.gitignore @@ -0,0 +1,3 @@ +outputs +matmul.cpp +matmul.pto diff --git a/examples/aot/matmul_optimization_guide/experimental/README.md b/examples/aot/matmul_optimization_guide/experimental/README.md new file mode 100644 index 00000000..009a4e6a --- /dev/null +++ b/examples/aot/matmul_optimization_guide/experimental/README.md @@ -0,0 +1,21 @@ +Usage: + +```bash +bash ./compile.sh +python ./run_matmul.py + +python ./bench_matmul.py +``` + +Benchmark outputs: +- CSV: `outputs/csv/bench_matmul.csv` +- Optional plots (if `matplotlib` is installed): `outputs/plots/flops_n{N}_k{K}.png` + +Useful benchmark options: + +``` +python ./bench_matmul.py --csv outputs/csv/my_bench.csv --plot-dir outputs/plots +python ./bench_matmul.py --m-list 512,1024,2048,4096 +python ./bench_matmul.py --warmup 10 --repeat 50 +python ./bench_matmul.py --lib ./matmul_kernel.so +``` diff --git a/examples/aot/matmul_optimization_guide/experimental/bench_matmul.py b/examples/aot/matmul_optimization_guide/experimental/bench_matmul.py new file mode 100644 index 00000000..6054f407 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/experimental/bench_matmul.py @@ -0,0 +1,487 @@ +import argparse +import csv +import ctypes +import os +from pathlib import Path + 
+import torch +import torch.nn.functional as F +import torch_npu # noqa: F401 + +from ptodsl.test_util import get_test_device + + +BLOCK_DIM = 24 +SWIZZLE_DIRECTION_LIST = [0, 1] +SWIZZLE_COUNT_LIST = [1, 3, 5] +NO_SWIZZLE_DIRECTION = -1 +NO_SWIZZLE_COUNT = 1 +M_LIST = [128 * i for i in range(1, 37, 4)] # 128, ..., 4224 +SHAPES_NK = [ + (4096, 4096), + (8192, 8192), + (16384, 16384), +] +N_WARMUP = 5 +N_REPEAT = 20 +DEFAULT_CSV_REL_PATH = Path("outputs") / "csv" / "bench_matmul.csv" +DEFAULT_PLOT_DIR = Path("outputs") / "plots" + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path): + lib = ctypes.CDLL(os.path.abspath(lib_path)) + lib.call_kernel.argtypes = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ] + lib.call_kernel.restype = None + + def matmul_abt( + a, + b, + *, + block_dim=24, + swizzle_direction=1, + swizzle_count=3, + stream_ptr=None, + ): + if a.ndim != 2 or b.ndim != 2: + raise ValueError("matmul_abt expects 2D tensors: a[M,K], b[N,K]") + if a.shape[1] != b.shape[1]: + raise ValueError( + f"K mismatch: a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)}" + ) + if a.dtype != torch.float16 or b.dtype != torch.float16: + raise ValueError("matmul_abt currently supports float16 inputs only") + + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ + + m = int(a.shape[0]) + k = int(a.shape[1]) + n = int(b.shape[0]) + c = torch.empty((m, n), device=a.device, dtype=a.dtype) + + lib.call_kernel( + block_dim, + stream_ptr, + torch_to_ctypes(a), + torch_to_ctypes(b), + torch_to_ctypes(c), + m, + n, + k, + swizzle_direction, + swizzle_count, + ) + return c + + return matmul_abt + + +def _parse_int_list(raw: str): + parts = [p.strip() for p in raw.split(",") if p.strip()] + if not parts: + raise ValueError("List cannot be empty.") + return [int(p) for p in 
parts] + + +def _parse_args(): + parser = argparse.ArgumentParser( + description=( + "Benchmark AOT matmul_abt vs torch.nn.functional.linear, " + "save CSV, and optionally plot throughput." + ) + ) + parser.add_argument( + "--lib", + type=str, + default="matmul_kernel.so", + help="Path to shared library with call_kernel (default: matmul_kernel.so).", + ) + parser.add_argument( + "--csv", + type=str, + default=str(DEFAULT_CSV_REL_PATH), + help=f"Output CSV path (default: {DEFAULT_CSV_REL_PATH}).", + ) + parser.add_argument( + "--plot-dir", + type=str, + default=str(DEFAULT_PLOT_DIR), + help=f"Plot output directory (default: {DEFAULT_PLOT_DIR}).", + ) + parser.add_argument( + "--m-list", + type=str, + default=",".join(str(m) for m in M_LIST), + help="Comma-separated M values (default: script M_LIST).", + ) + parser.add_argument( + "--warmup", + type=int, + default=N_WARMUP, + help=f"Warmup iterations (default: {N_WARMUP}).", + ) + parser.add_argument( + "--repeat", + type=int, + default=N_REPEAT, + help=f"Timed iterations (default: {N_REPEAT}).", + ) + return parser.parse_args() + + +def _time_fn(fn, a_list, b_list, warmup, repeat): + for a, b in zip(a_list[:warmup], b_list[:warmup]): + fn(a, b) + torch.npu.synchronize() + + start = torch.npu.Event(enable_timing=True) + end = torch.npu.Event(enable_timing=True) + start.record() + for a, b in zip(a_list[warmup : warmup + repeat], b_list[warmup : warmup + repeat]): + fn(a, b) + end.record() + torch.npu.synchronize() + + elapsed_ms = start.elapsed_time(end) + return elapsed_ms * 1000.0 / repeat + + +def _swizzle_cases(): + # direction=-1 disables swizzle; treat it as one dedicated baseline case. 
+ cases = [(NO_SWIZZLE_DIRECTION, NO_SWIZZLE_COUNT)] + for direction in SWIZZLE_DIRECTION_LIST: + if direction == NO_SWIZZLE_DIRECTION: + continue + for count in SWIZZLE_COUNT_LIST: + cases.append((direction, count)) + return cases + + +def _maybe_plot(rows, plot_dir): + try: + import matplotlib.pyplot as plt + except ImportError: + print("matplotlib not installed; skipping plot generation.") + return + + # Prefer a white-grid background style for readability in reports. + style_candidates = ("seaborn-v0_8-whitegrid", "seaborn-whitegrid") + for style_name in style_candidates: + try: + plt.style.use(style_name) + break + except OSError: + continue + + plt.rcParams["figure.facecolor"] = "white" + plt.rcParams["axes.facecolor"] = "white" + + plot_dir.mkdir(parents=True, exist_ok=True) + + grouped = {} + for row in rows: + key = (row["n"], row["k"]) + grouped.setdefault(key, []).append(row) + + for (n, k), chunk in grouped.items(): + m_values = sorted({r["m"] for r in chunk}) + swizzles = sorted( + {(r["swizzle_direction"], r["swizzle_count"]) for r in chunk}, + key=lambda x: (x[0], x[1]), + ) + + linear_by_m = {} + for m in m_values: + candidates = [r for r in chunk if r["m"] == m] + linear_by_m[m] = sum(r["linear_tflops"] for r in candidates) / len( + candidates + ) + + plt.figure(figsize=(9, 5)) + plt.plot( + m_values, + [linear_by_m[m] for m in m_values], + marker="x", + linestyle="--", + color="#111111", + label="F.linear", + ) + + cmap = plt.get_cmap("tab10") + for idx, (direction, count) in enumerate(swizzles): + series = [] + for m in m_values: + candidates = [ + r + for r in chunk + if r["m"] == m + and r["swizzle_direction"] == direction + and r["swizzle_count"] == count + ] + if not candidates: + series.append(float("nan")) + else: + series.append( + sum(r["custom_tflops"] for r in candidates) / len(candidates) + ) + is_baseline = direction == NO_SWIZZLE_DIRECTION + label = ( + "matmul_abt(no-swizzle)" + if is_baseline + else f"matmul_abt(d={direction}, 
c={count})" + ) + plt.plot( + m_values, + series, + marker="o", + linestyle="-", + color=cmap(idx % 10), + alpha=1.0 if is_baseline else 0.7, + label=label, + ) + + plt.title(f"TFLOPS vs M (N={n}, K={k})") + plt.xlabel("M") + plt.ylabel("TFLOPS") + plt.xlim(left=0) + plt.ylim(bottom=0) + plt.grid(alpha=0.25) + plt.legend(fontsize=8) + plt.tight_layout() + out = plot_dir / f"flops_n{n}_k{k}.png" + plt.savefig(out, dpi=160) + plt.close() + print(f"Saved plot: {out}") + + plt.figure(figsize=(10, 5)) + ax_left = plt.gca() + cmap = plt.get_cmap("tab10") + + for idx, (direction, count) in enumerate(swizzles): + speedup_series = [] + for m in m_values: + candidates = [ + r + for r in chunk + if r["m"] == m + and r["swizzle_direction"] == direction + and r["swizzle_count"] == count + ] + if not candidates: + speedup_series.append(float("nan")) + else: + speedup_series.append( + sum(r["speedup_vs_no_swizzle"] for r in candidates) + / len(candidates) + ) + + is_baseline = direction == NO_SWIZZLE_DIRECTION + alpha = 1.0 if is_baseline else 0.7 + color = cmap(idx % 10) + base_label = ( + "no-swizzle baseline" if is_baseline else f"d={direction}, c={count}" + ) + speedup_label = f"speedup {base_label}" + + ax_left.plot( + m_values, + speedup_series, + marker="o", + linestyle="-", + color=color, + alpha=alpha, + label=speedup_label, + ) + + ax_left.set_title(f"Speed-up vs no-swizzle (N={n}, K={k})") + ax_left.set_xlabel("M") + ax_left.set_ylabel("Speed-up vs no-swizzle") + ax_left.set_xlim(left=0) + ax_left.set_ylim(bottom=0) + ax_left.grid(alpha=0.25) + ax_left.legend(fontsize=8) + plt.tight_layout() + ratio_out = plot_dir / f"ratio_n{n}_k{k}.png" + plt.savefig(ratio_out, dpi=160) + plt.close() + print(f"Saved plot: {ratio_out}") + + +def main(): + args = _parse_args() + base_dir = Path(__file__).resolve().parent + device = get_test_device() + torch.npu.set_device(device) + + m_list = _parse_int_list(args.m_list) + if args.warmup < 1 or args.repeat < 1: + raise 
ValueError("--warmup and --repeat must be positive integers.") + + lib_path = Path(args.lib) + if not lib_path.is_absolute(): + lib_path = base_dir / lib_path + if not lib_path.exists(): + raise FileNotFoundError(f"Kernel library not found: {lib_path}") + + csv_path = Path(args.csv) + if not csv_path.is_absolute(): + csv_path = base_dir / csv_path + csv_path.parent.mkdir(parents=True, exist_ok=True) + + plot_dir = Path(args.plot_dir) + if not plot_dir.is_absolute(): + plot_dir = base_dir / plot_dir + + matmul_abt = load_lib(str(lib_path)) + torch.manual_seed(0) + + rows = [] + swizzle_cases = _swizzle_cases() + total_cases = len(m_list) * len(SHAPES_NK) * len(swizzle_cases) + case_idx = 0 + + for n, k in SHAPES_NK: + for m in m_list: + alloc = args.warmup + args.repeat + a_list = [ + torch.randn(m, k, dtype=torch.float16, device=device) + for _ in range(alloc) + ] + b_list = [ + torch.randn(n, k, dtype=torch.float16, device=device) + for _ in range(alloc) + ] + c_ref = F.linear(a_list[0], b_list[0]) + torch.npu.synchronize() + + linear_time_us = _time_fn( + F.linear, a_list, b_list, args.warmup, args.repeat + ) + flops = 2.0 * m * n * k + linear_tflops = flops / linear_time_us / 1e6 + + print(f"\n(M,N,K)=({m},{n},{k}) F.linear={linear_tflops:.3f} TFLOPS") + + case_rows = [] + no_swizzle_time_us = None + no_swizzle_tflops = None + + for swizzle_direction, swizzle_count in swizzle_cases: + case_idx += 1 + + def _custom(a, b, _d=swizzle_direction, _c=swizzle_count): + return matmul_abt( + a, + b, + block_dim=BLOCK_DIM, + swizzle_direction=_d, + swizzle_count=_c, + ) + + c = _custom(a_list[0], b_list[0]) + torch.npu.synchronize() + max_absdiff = float((c - c_ref).abs().max().item()) + mean_absdiff = float((c - c_ref).abs().mean().item()) + custom_time_us = _time_fn( + _custom, a_list, b_list, args.warmup, args.repeat + ) + custom_tflops = flops / custom_time_us / 1e6 + flops_fraction_vs_linear = custom_tflops / linear_tflops + + if ( + swizzle_direction == 
NO_SWIZZLE_DIRECTION + and swizzle_count == NO_SWIZZLE_COUNT + ): + no_swizzle_time_us = custom_time_us + no_swizzle_tflops = custom_tflops + + case_rows.append( + { + "case_idx": case_idx, + "m": m, + "n": n, + "k": k, + "block_dim": BLOCK_DIM, + "swizzle_direction": swizzle_direction, + "swizzle_count": swizzle_count, + "linear_time_us": linear_time_us, + "linear_tflops": linear_tflops, + "custom_time_us": custom_time_us, + "custom_tflops": custom_tflops, + "flops_fraction_vs_linear": flops_fraction_vs_linear, + "max_absdiff": max_absdiff, + "mean_absdiff": mean_absdiff, + } + ) + + del a_list, b_list, c, c_ref + torch.npu.empty_cache() + + if no_swizzle_time_us is None or no_swizzle_tflops is None: + raise RuntimeError( + "No no-swizzle baseline result found " + f"(direction={NO_SWIZZLE_DIRECTION}, count={NO_SWIZZLE_COUNT})." + ) + + for record in case_rows: + record["no_swizzle_time_us"] = no_swizzle_time_us + record["no_swizzle_tflops"] = no_swizzle_tflops + record["speedup_vs_no_swizzle"] = ( + no_swizzle_time_us / record["custom_time_us"] + ) + progress_idx = record.pop("case_idx") + + print( + f" [{progress_idx:03d}/{total_cases}] " + f"d={record['swizzle_direction']} c={record['swizzle_count']} " + f"custom={record['custom_tflops']:.3f} TFLOPS " + f"frac_of_linear={record['flops_fraction_vs_linear']:.3f} " + f"speedup_vs_no_swizzle={record['speedup_vs_no_swizzle']:.3f}x " + f"mean_diff={record['mean_absdiff']:.3e}" + ) + rows.append(record) + + fieldnames = [ + "m", + "n", + "k", + "block_dim", + "swizzle_direction", + "swizzle_count", + "linear_time_us", + "linear_tflops", + "custom_time_us", + "custom_tflops", + "flops_fraction_vs_linear", + "no_swizzle_time_us", + "no_swizzle_tflops", + "speedup_vs_no_swizzle", + "max_absdiff", + "mean_absdiff", + ] + with csv_path.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + print(f"\nSaved benchmark CSV: {csv_path}") 
+ + _maybe_plot(rows, plot_dir) + + +if __name__ == "__main__": + main() diff --git a/examples/aot/matmul_optimization_guide/experimental/caller.cpp b/examples/aot/matmul_optimization_guide/experimental/caller.cpp new file mode 100644 index 00000000..07774763 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/experimental/caller.cpp @@ -0,0 +1,28 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "matmul.cpp" +#endif + +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, + void *stream, + uint8_t *x, + uint8_t *y, + uint8_t *z, + int M, + int N, + int K, + int swizzle_direction, + int swizzle_count) +{ + matmul_kernel_ABt<<>>( + reinterpret_cast(x), + reinterpret_cast(y), + reinterpret_cast(z), + static_cast(M), + static_cast(N), + static_cast(K), + static_cast(swizzle_direction), + static_cast(swizzle_count)); +} diff --git a/examples/aot/matmul_optimization_guide/experimental/compile.sh b/examples/aot/matmul_optimization_guide/experimental/compile.sh new file mode 100644 index 00000000..9eb80bca --- /dev/null +++ b/examples/aot/matmul_optimization_guide/experimental/compile.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +set -euo pipefail + +rm -f matmul.pto matmul.cpp matmul_kernel.so + +python ./matmul_builder.py > matmul.pto +ptoas matmul.pto -o matmul.cpp + +bisheng -fPIC -shared -xcce -O2 -std=c++17 \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -I"${ASCEND_TOOLKIT_HOME}/include" \ + -DKERNEL_CPP="\"matmul.cpp\"" \ + ./caller.cpp \ + -o ./matmul_kernel.so diff --git a/examples/aot/matmul_optimization_guide/experimental/matmul_builder.py b/examples/aot/matmul_optimization_guide/experimental/matmul_builder.py new file mode 100644 index 00000000..a807414c --- /dev/null +++ b/examples/aot/matmul_optimization_guide/experimental/matmul_builder.py @@ -0,0 +1,364 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + + +def build(): + M_TILE = 128 + K_QTILE = 64 + K_TILE = 256 + K_DTILE = 512 + N_FULL = 256 + 
N_HALF = 128 + + def meta_data(): + dtype = pto.float16 + acc_dtype = pto.float32 + ptr_type = pto.PtrType(dtype) + i32 = pto.int32 + tv_a = pto.TensorType(rank=2, dtype=dtype) + tv_b = pto.TensorType(rank=2, dtype=dtype) + tv_c = pto.TensorType(rank=2, dtype=dtype) + + tile_view_a = pto.SubTensorType(shape=[M_TILE, K_DTILE], dtype=dtype) + tile_view_b_256 = pto.SubTensorType(shape=[K_TILE, N_FULL], dtype=dtype) + tile_view_b_128 = pto.SubTensorType(shape=[K_TILE, N_HALF], dtype=dtype) + tile_view_c_256 = pto.SubTensorType(shape=[M_TILE, N_FULL], dtype=dtype) + tile_view_c_128 = pto.SubTensorType(shape=[M_TILE, N_HALF], dtype=dtype) + + b_l1_cfg = pto.TileBufConfig( + blayout="RowMajor", slayout="ColMajor", s_fractal_size=512 + ) + + tile_buf_a_l1 = pto.TileBufType( + shape=[M_TILE, K_DTILE], dtype=dtype, memory_space="MAT" + ) + tile_buf_b_l1_256 = pto.TileBufType( + shape=[K_TILE, N_FULL], dtype=dtype, memory_space="MAT", config=b_l1_cfg + ) + tile_buf_b_l1_128 = pto.TileBufType( + shape=[K_TILE, N_HALF], dtype=dtype, memory_space="MAT", config=b_l1_cfg + ) + tile_buf_a_l0 = pto.TileBufType( + shape=[M_TILE, K_QTILE], dtype=dtype, memory_space="LEFT" + ) + tile_buf_b_l0_256 = pto.TileBufType( + shape=[K_QTILE, N_FULL], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_b_l0_128 = pto.TileBufType( + shape=[K_QTILE, N_HALF], dtype=dtype, memory_space="RIGHT" + ) + tile_buf_c_256 = pto.TileBufType( + shape=[M_TILE, N_FULL], dtype=acc_dtype, memory_space="ACC" + ) + tile_buf_c_128 = pto.TileBufType( + shape=[M_TILE, N_HALF], dtype=acc_dtype, memory_space="ACC" + ) + + return { + "ptr_type": ptr_type, + "i32": i32, + "tv_a": tv_a, + "tv_b": tv_b, + "tv_c": tv_c, + "tile_view_a": tile_view_a, + "tile_view_b_256": tile_view_b_256, + "tile_view_b_128": tile_view_b_128, + "tile_view_c_256": tile_view_c_256, + "tile_view_c_128": tile_view_c_128, + "tile_buf_a_l1": tile_buf_a_l1, + "tile_buf_b_l1_256": tile_buf_b_l1_256, + "tile_buf_b_l1_128": tile_buf_b_l1_128, + 
"tile_buf_a_l0": tile_buf_a_l0, + "tile_buf_b_l0_256": tile_buf_b_l0_256, + "tile_buf_b_l0_128": tile_buf_b_l0_128, + "tile_buf_c_256": tile_buf_c_256, + "tile_buf_c_128": tile_buf_c_128, + } + + def swizzle_zn(li, m_loop, n_loop, cSwizzle, cSwizzleM1, c1, c2): + tile_block_loop = (m_loop + cSwizzleM1) // cSwizzle + tile_block_span = cSwizzle * n_loop + tile_block_idx = li // tile_block_span + in_tile_block_idx = li % tile_block_span + is_last_block = tile_block_idx == (tile_block_loop - c1) + n_row_tail = m_loop - cSwizzle * tile_block_idx + n_row = s.select(is_last_block, n_row_tail, cSwizzle) + m_idx = tile_block_idx * cSwizzle + (in_tile_block_idx % n_row) + n_idx = in_tile_block_idx // n_row + odd_block = (tile_block_idx % c2) == c1 + flipped_n_idx = n_loop - n_idx - c1 + n_idx = s.select(odd_block, flipped_n_idx, n_idx) + return m_idx, n_idx + + def swizzle_nz(li, m_loop, n_loop, cSwizzle, cSwizzleM1, c1, c2): + tile_block_loop = (n_loop + cSwizzleM1) // cSwizzle + tile_block_span = cSwizzle * m_loop + tile_block_idx = li // tile_block_span + in_tile_block_idx = li % tile_block_span + is_last_block = tile_block_idx == (tile_block_loop - c1) + n_col_tail = n_loop - cSwizzle * tile_block_idx + n_col = s.select(is_last_block, n_col_tail, cSwizzle) + m_idx = in_tile_block_idx // n_col + n_idx = tile_block_idx * cSwizzle + (in_tile_block_idx % n_col) + odd_block = (tile_block_idx % c2) == c1 + flipped_m_idx = m_loop - m_idx - c1 + m_idx = s.select(odd_block, flipped_m_idx, m_idx) + return m_idx, n_idx + + def level1_loop_mn_dynamic_tilesize( + n_tile: int, + b_view_type, + c_view_type, + b_l1_type, + b_l0_type, + c_type, + m_offset, + n_offset, + k_dtile_num, + li, + core_loop, + bid, + num_blocks, + tvA, + tvB, + tvC, + ): + c0 = const(0) + c1 = const(1) + c2 = const(2) + cKT = const(K_TILE) + cKD = const(K_DTILE) + cNT = const(n_tile) + + a_l1 = [pto.alloc_tile(tile_buf_a_l1), pto.alloc_tile(tile_buf_a_l1)] + b_l1 = [pto.alloc_tile(b_l1_type), 
pto.alloc_tile(b_l1_type)] + a_l0 = [pto.alloc_tile(tile_buf_a_l0), pto.alloc_tile(tile_buf_a_l0)] + b_l0 = [pto.alloc_tile(b_l0_type), pto.alloc_tile(b_l0_type)] + c_l0 = pto.alloc_tile(c_type) + + not_first_tile = li != bid + with pto.if_context(not_first_tile): + pto.wait_event("STORE_ACC", "MATMUL", event_id=0) + + sv_a0 = pto.slice_view( + tile_view_a, + source=tvA, + offsets=[m_offset, c0], + sizes=[const(M_TILE), cKD], + ) + pto.wait_event("MOV_M2L", "LOAD", event_id=0) + pto.load(sv_a0, a_l1[0]) + pto.record_event("LOAD", "MOV_M2L", event_id=0) + + for k_idx in pto.range(c0, k_dtile_num, c1): + k_offset = k_idx * cKD + is_curr0 = (k_idx % c2) == c0 + + def level2_loop_k(curr_id, next_id, a_curr, a_next): + is_first_k_tile = k_idx == c0 + + for h in range(2): + b_evt = 2 + h + h_off = const(h * K_TILE) + sv_b = pto.slice_view( + b_view_type, + source=tvB, + offsets=[k_offset + h_off, n_offset], + sizes=[cKT, cNT], + ) + + pto.wait_event("MOV_M2L", "LOAD", event_id=b_evt) + pto.load(sv_b, b_l1[h]) + pto.record_event("LOAD", "MOV_M2L", event_id=b_evt) + + for quarter in range(4): + phase = h * 4 + quarter + ping = phase & 1 + a_col = const(phase * K_QTILE) + b_row = const(quarter * K_QTILE) + + pto.wait_event("MATMUL", "MOV_M2L", event_id=ping) + if phase == 0: + pto.wait_event("LOAD", "MOV_M2L", event_id=curr_id) + + tile.extract(a_curr, c0, a_col, a_l0[ping]) + if phase == 7: + pto.record_event("MOV_M2L", "LOAD", event_id=curr_id) + + if quarter == 0: + pto.wait_event("LOAD", "MOV_M2L", event_id=b_evt) + + tile.extract(b_l1[h], b_row, c0, b_l0[ping]) + pto.record_event("MOV_M2L", "MATMUL", event_id=0) + + if quarter == 3: + pto.record_event("MOV_M2L", "LOAD", event_id=b_evt) + + pto.wait_event("MOV_M2L", "MATMUL", event_id=0) + if phase == 0: + pto.cond( + is_first_k_tile, + lambda: tile.matmul(a_l0[ping], b_l0[ping], c_l0), + lambda: tile.matmul_acc( + c_l0, a_l0[ping], b_l0[ping], c_l0 + ), + ) + else: + tile.matmul_acc(c_l0, a_l0[ping], b_l0[ping], c_l0) 
+ + pto.record_event("MATMUL", "MOV_M2L", event_id=ping) + + with pto.if_context(k_idx + c1 < k_dtile_num): + sv_a_next = pto.slice_view( + tile_view_a, + source=tvA, + offsets=[m_offset, k_offset + cKD], + sizes=[const(M_TILE), cKD], + ) + pto.wait_event("MOV_M2L", "LOAD", event_id=next_id) + pto.load(sv_a_next, a_next) + pto.record_event("LOAD", "MOV_M2L", event_id=next_id) + + with pto.if_context(is_curr0, has_else=True) as branch: + level2_loop_k(0, 1, a_l1[0], a_l1[1]) + with branch.else_context(): + level2_loop_k(1, 0, a_l1[1], a_l1[0]) + + sv_c = pto.slice_view( + c_view_type, + source=tvC, + offsets=[m_offset, n_offset], + sizes=[const(M_TILE), cNT], + ) + pto.record_wait_pair("MATMUL", "STORE_ACC", event_id=0) + pto.store(c_l0, sv_c) + + with pto.if_context(li + num_blocks < core_loop): + pto.record_event("STORE_ACC", "MATMUL", event_id=0) + + @to_ir_module(meta_data=meta_data) + def matmul_kernel_ABt( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + m_i32: "i32", + n_i32: "i32", + k_i32: "i32", + swizzle_direction_i32: "i32", + swizzle_count_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + c2 = const(2) + c128 = const(M_TILE) + c256 = const(N_FULL) + c128n = const(N_HALF) + c512 = const(K_DTILE) + + m_total = s.index_cast(m_i32) + n_total = s.index_cast(n_i32) + k_total = s.index_cast(k_i32) + swizzle_direction = s.index_cast(swizzle_direction_i32) + swizzle_count = s.index_cast(swizzle_count_i32) + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + cSwizzle = s.select(swizzle_count > c0, swizzle_count, c1) + cSwizzleM1 = cSwizzle - c1 + + n_loop = (n_total + c256 - c1) // c256 + m_loop = m_total // c128 + core_loop = n_loop * m_loop + k_dtile_num = k_total // c512 + + tvA = pto.as_tensor( + tv_a, ptr=a_ptr, shape=[m_total, k_total], strides=[k_total, c1] + ) + tvB = pto.as_tensor( + tv_b, + ptr=b_ptr, + shape=[k_total, n_total], + strides=[c1, k_total], + 
layout="DN", + ) + tvC = pto.as_tensor( + tv_c, ptr=c_ptr, shape=[m_total, n_total], strides=[n_total, c1] + ) + + pto.record_event("MATMUL", "MOV_M2L", event_id=[0, 1]) + pto.record_event("MOV_M2L", "LOAD", event_id=[0, 1, 2, 3]) + + def level1_loop_mn(m_offset, n_offset, li): + # TODO: make a simpler version that only uses full-tile (256) branch, and reduce the types needed in meta_data + n_tile_size = s.select(n_offset + c256 > n_total, c128n, c256) + shared_args = [ + m_offset, + n_offset, + k_dtile_num, + li, + core_loop, + bid, + num_blocks, + tvA, + tvB, + tvC, + ] + with pto.if_context(n_tile_size == c256, has_else=True) as branch: + level1_loop_mn_dynamic_tilesize( + N_FULL, + tile_view_b_256, + tile_view_c_256, + tile_buf_b_l1_256, + tile_buf_b_l0_256, + tile_buf_c_256, + *shared_args, + ) + with branch.else_context(): + level1_loop_mn_dynamic_tilesize( + N_HALF, + tile_view_b_128, + tile_view_c_128, + tile_buf_b_l1_128, + tile_buf_b_l0_128, + tile_buf_c_128, + *shared_args, + ) + + for li in pto.range(bid, core_loop, num_blocks): + with pto.if_context( + swizzle_direction == c0, has_else=True + ) as c0_branch: + m_idx, n_idx = swizzle_zn( + li, m_loop, n_loop, cSwizzle, cSwizzleM1, c1, c2 + ) + level1_loop_mn(m_idx * c128, n_idx * c256, li) + + with c0_branch.else_context(): + with pto.if_context( + swizzle_direction == c1, has_else=True + ) as c1_branch: + m_idx, n_idx = swizzle_nz( + li, m_loop, n_loop, cSwizzle, cSwizzleM1, c1, c2 + ) + level1_loop_mn(m_idx * c128, n_idx * c256, li) + + with c1_branch.else_context(): + # Default linear mapping, used when swizzle_direction is not 0/1. 
+ m_idx = li // n_loop + n_idx = li % n_loop + level1_loop_mn(m_idx * c128, n_idx * c256, li) + + pto.wait_event("MOV_M2L", "LOAD", event_id=3) + pto.wait_event("MOV_M2L", "LOAD", event_id=2) + pto.wait_event("MOV_M2L", "LOAD", event_id=1) + pto.wait_event("MOV_M2L", "LOAD", event_id=0) + pto.wait_event("MATMUL", "MOV_M2L", event_id=0) + pto.wait_event("MATMUL", "MOV_M2L", event_id=1) + + return matmul_kernel_ABt + + +if __name__ == "__main__": + print(build()) diff --git a/examples/aot/matmul_optimization_guide/experimental/run_matmul.py b/examples/aot/matmul_optimization_guide/experimental/run_matmul.py new file mode 100644 index 00000000..54a2339f --- /dev/null +++ b/examples/aot/matmul_optimization_guide/experimental/run_matmul.py @@ -0,0 +1,211 @@ +import ctypes +import os +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +import torch_npu + +from ptodsl.test_util import get_test_device + + +BLOCK_DIM_LIST = [1, 20, 24] +SWIZZLE_DIRECTION_LIST = [0, 1] +SWIZZLE_COUNT_LIST = [1, 3, 5] +M_LIST = [128 * i for i in range(1, 37, 4)] # 128, ..., 4224 +SHAPES_NK = [ + (4096, 4096), + (8192, 8192), + (16384, 16384), +] +MAX_ABSDIFF_THRESHOLD = 0.5 +MEAN_ABSDIFF_THRESHOLD = 1e-4 + + +@dataclass +class CaseResult: + m: int + n: int + k: int + block_dim: int + swizzle_direction: int + swizzle_count: int + max_absdiff: float + mean_absdiff: float + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path): + lib = ctypes.CDLL(os.path.abspath(lib_path)) + lib.call_kernel.argtypes = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ] + lib.call_kernel.restype = None + + def matmul_abt( + a, + b, + *, + block_dim=24, + swizzle_direction=1, + swizzle_count=3, + stream_ptr=None, + ): + if a.ndim != 2 or b.ndim != 2: + raise ValueError("matmul_abt expects 2D tensors: 
a[M,K], b[N,K]") + if a.shape[1] != b.shape[1]: + raise ValueError( + f"K mismatch: a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)}" + ) + if a.dtype != torch.float16 or b.dtype != torch.float16: + raise ValueError("matmul_abt currently supports float16 inputs only") + + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ + + m = int(a.shape[0]) + k = int(a.shape[1]) + n = int(b.shape[0]) + c = torch.empty((m, n), device=a.device, dtype=a.dtype) + + lib.call_kernel( + block_dim, + stream_ptr, + torch_to_ctypes(a), + torch_to_ctypes(b), + torch_to_ctypes(c), + m, + n, + k, + swizzle_direction, + swizzle_count, + ) + return c + + return matmul_abt + + +def run_case(matmul_abt, a, b, c_ref, *, block_dim, swizzle_direction, swizzle_count): + c = matmul_abt( + a, + b, + block_dim=block_dim, + swizzle_direction=swizzle_direction, + swizzle_count=swizzle_count, + ) + torch.npu.synchronize() + result = CaseResult( + m=int(a.shape[0]), + n=int(b.shape[0]), + k=int(a.shape[1]), + block_dim=block_dim, + swizzle_direction=swizzle_direction, + swizzle_count=swizzle_count, + max_absdiff=float((c - c_ref).abs().max().item()), + mean_absdiff=float((c - c_ref).abs().mean().item()), + ) + del c + torch.npu.empty_cache() + return result + + +def test_matmul(): + device = get_test_device() + torch.npu.set_device(device) + matmul_abt = load_lib("./matmul_kernel.so") + + torch.manual_seed(0) + checked_cases = 0 + global_worst = None + + for m in M_LIST: + for n, k in SHAPES_NK: + a = torch.randn(m, k, dtype=torch.float16, device=device) + b = torch.randn(n, k, dtype=torch.float16, device=device) + c_ref = F.linear(a, b) + torch.npu.synchronize() + + shape_worst = None + for block_dim in BLOCK_DIM_LIST: + for swizzle_direction in SWIZZLE_DIRECTION_LIST: + for swizzle_count in SWIZZLE_COUNT_LIST: + result = run_case( + matmul_abt, + a, + b, + c_ref, + block_dim=block_dim, + swizzle_direction=swizzle_direction, + swizzle_count=swizzle_count, + ) + 
checked_cases += 1 + + if ( + shape_worst is None + or result.max_absdiff > shape_worst.max_absdiff + or ( + result.max_absdiff == shape_worst.max_absdiff + and result.mean_absdiff > shape_worst.mean_absdiff + ) + ): + shape_worst = result + + if ( + global_worst is None + or result.max_absdiff > global_worst.max_absdiff + or ( + result.max_absdiff == global_worst.max_absdiff + and result.mean_absdiff > global_worst.mean_absdiff + ) + ): + global_worst = result + + del a, b, c_ref + torch.npu.empty_cache() + + print( + f"(m, n, k)=({m}, {n}, {k}) " + f"worst(block_dim, swizzle_direction, swizzle_count)=" + f"({shape_worst.block_dim}, {shape_worst.swizzle_direction}, " + f"{shape_worst.swizzle_count}) " + f"max_absdiff={shape_worst.max_absdiff:.6f} " + f"mean_absdiff={shape_worst.mean_absdiff:.6f}" + ) + + print(f"checked_cases={checked_cases}") + print( + "global_worst " + f"max_absdiff={global_worst.max_absdiff:.6f} " + f"mean_absdiff={global_worst.mean_absdiff:.6f} " + f"at (m, n, k, block_dim, swizzle_direction, swizzle_count)=" + f"({global_worst.m}, {global_worst.n}, {global_worst.k}, " + f"{global_worst.block_dim}, {global_worst.swizzle_direction}, " + f"{global_worst.swizzle_count})" + ) + + if global_worst.max_absdiff > MAX_ABSDIFF_THRESHOLD: + raise AssertionError( + f"max_absdiff {global_worst.max_absdiff:.6f} exceeds " + f"threshold {MAX_ABSDIFF_THRESHOLD:.6f}" + ) + if global_worst.mean_absdiff > MEAN_ABSDIFF_THRESHOLD: + raise AssertionError( + f"mean_absdiff {global_worst.mean_absdiff:.6f} exceeds " + f"threshold {MEAN_ABSDIFF_THRESHOLD:.6f}" + ) + + +if __name__ == "__main__": + test_matmul() diff --git a/examples/aot/matmul_optimization_guide/fig/cachehit_N16384.png b/examples/aot/matmul_optimization_guide/fig/cachehit_N16384.png new file mode 100644 index 00000000..cddf5c6a Binary files /dev/null and b/examples/aot/matmul_optimization_guide/fig/cachehit_N16384.png differ diff --git 
a/examples/aot/matmul_optimization_guide/fig/cachehit_N16384_swizzle.png b/examples/aot/matmul_optimization_guide/fig/cachehit_N16384_swizzle.png new file mode 100644 index 00000000..00fe5f3e Binary files /dev/null and b/examples/aot/matmul_optimization_guide/fig/cachehit_N16384_swizzle.png differ diff --git a/examples/aot/matmul_optimization_guide/fig/cachehit_N4096.png b/examples/aot/matmul_optimization_guide/fig/cachehit_N4096.png new file mode 100644 index 00000000..7d67c6e0 Binary files /dev/null and b/examples/aot/matmul_optimization_guide/fig/cachehit_N4096.png differ diff --git a/examples/aot/matmul_optimization_guide/fig/flops_step1_baseline.png b/examples/aot/matmul_optimization_guide/fig/flops_step1_baseline.png new file mode 100644 index 00000000..7428f65d Binary files /dev/null and b/examples/aot/matmul_optimization_guide/fig/flops_step1_baseline.png differ diff --git a/examples/aot/matmul_optimization_guide/fig/flops_step2_doublebuf.png b/examples/aot/matmul_optimization_guide/fig/flops_step2_doublebuf.png new file mode 100644 index 00000000..e89485f3 Binary files /dev/null and b/examples/aot/matmul_optimization_guide/fig/flops_step2_doublebuf.png differ diff --git a/examples/aot/matmul_optimization_guide/fig/flops_step3_swizzle.png b/examples/aot/matmul_optimization_guide/fig/flops_step3_swizzle.png new file mode 100644 index 00000000..a04fbbda Binary files /dev/null and b/examples/aot/matmul_optimization_guide/fig/flops_step3_swizzle.png differ diff --git a/examples/aot/matmul_optimization_guide/fig/flops_step4_manual_pipeline.png b/examples/aot/matmul_optimization_guide/fig/flops_step4_manual_pipeline.png new file mode 100644 index 00000000..5e29a885 Binary files /dev/null and b/examples/aot/matmul_optimization_guide/fig/flops_step4_manual_pipeline.png differ diff --git a/examples/aot/matmul_optimization_guide/fig/pipeline_N1024_baseline.png b/examples/aot/matmul_optimization_guide/fig/pipeline_N1024_baseline.png new file mode 100644 index 
00000000..98fc5246 Binary files /dev/null and b/examples/aot/matmul_optimization_guide/fig/pipeline_N1024_baseline.png differ diff --git a/examples/aot/matmul_optimization_guide/fig/pipeline_N1024_doublebuf.png b/examples/aot/matmul_optimization_guide/fig/pipeline_N1024_doublebuf.png new file mode 100644 index 00000000..84291314 Binary files /dev/null and b/examples/aot/matmul_optimization_guide/fig/pipeline_N1024_doublebuf.png differ diff --git a/examples/aot/matmul_optimization_guide/mamtul_optim_guide_zh.md b/examples/aot/matmul_optimization_guide/mamtul_optim_guide_zh.md new file mode 100644 index 00000000..23a502c1 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/mamtul_optim_guide_zh.md @@ -0,0 +1,344 @@ +# 从零手搓昇腾Matmul,用100行Python追平CANN主线库性能
(基于 PTO-ISA 的逐步优化指南) + +For English version see [matmul_optim_guide.md](./matmul_optim_guide.md) + +- 日期:2026/03/12 + +# 目录 + +- [写作动机](#motivation) +- [第 0 步:给CUDA/Triton用户的NPU编程速通](#step-0-npu-programming-crash-course-for-cudatriton-programmers) + - [NPU kernel launch行为](#typical-kernel-launch-syntax) + - [Software pipelining,自动vs手动](#auto-vs-manual-software-pipelining) +- [第 1 步:功能正确的基础版本](#step-1-functionally-correct-naive-version) +- [第 2 步:Double buffering](#step-2-double-buffering) +- [第 3 步:通过 "Swizzling" 提升 L2 cache 复用](#step-3-swizzling-for-l2-cache-reuse) +- [第 4 步:(可选)手动 software pipelining](#step-4-optional-manual-software-pipelining) +- [附录 A:PTO-DSL 语法说明](#appendix-a-pto-dsl-syntax-note) +- [附录 B:NPU profiler 使用方法](#appendix-b-using-npu-profiler) + +**复现本文全部结果**,见 [README.md](./README.md) 里的命令。 + + +# 写作动机 + +本文是NPU版本的“Matmul算子逐步优化实录”。这类文章在友商GPU十分流行(比如[这篇A100的](https://siboehm.com/articles/22/CUDA-MMM)和[这篇H100的](https://cudaforfun.substack.com/p/outperforming-cublas-on-h100-a-worklog)),但在我司的NPU上似乎还没有过公开的“从零手搓”教程。 + +我们会逐步把一个基于**约100行Python DSL**的算子优化到持平主线库的性能。对照的性能基线是NPU上的`torch.matmul`,内部调用[aclnnMatmul](https://www.hiascend.com/document/detail/zh/canncommercial/850/API/aolapi/context/ops-nn/aclnnMatmul.md)(NPU的“cuBLAS平替”),实现方式为[上万行的AscendC代码](https://gitcode.com/cann/ops-nn/tree/v8.5.0/matmul/mat_mul_v3/op_kernel)。 + +本教程的代码坚持:**极简、易于魔改、不带黑盒模板封装**,只提炼**少数最关键的**性能优化点。还有些更全面的、对corner case考虑更细致的Matmul实现例如[Catlass的矩阵乘模板总结](https://gitcode.com/cann/catlass/blob/v1.4.0/docs/contents/advanced/matmul_template_summary.md)和[AscendC的Matmul性能优化策略总览](https://www.hiascend.com/document/detail/zh/canncommercial/850/opdevg/Ascendcopdevg/atlas_ascendc_best_practices_10_10006.html),把大量优化都藏在了模板和封装里,适合作为后续进阶材料。 + + +# 第 0 步:给 CUDA/Triton 用户的 NPU 编程速通 + +(如果你已经写过NPU算子,可快速略过本节) + + +## NPU kernel launch行为 + +NPU上[SPMD](https://en.wikipedia.org/wiki/Single_program,_multiple_data)风格的kernel看起来和CUDA/Triton语法**似乎很像**: +- 内置变量`block_idx`和`block_num`用于每个core的地址offset计算 -- 
[示例](https://github.com/huawei-csl/pto-dsl/blob/b9b0c4abdcb84b84db53f27ffcb4ce8aa1b67316/examples/jit/add_dynamic_multicore/run_add.py) +- CUDA画风的`kernel_name<<<block_dim, ...>>>(args)`kernel launch方式 -- [示例](https://github.com/huawei-csl/pto-dsl/blob/b9b0c4abdcb84b84db53f27ffcb4ce8aa1b67316/examples/aot/elementwise/add_dynamic_multicore/caller.cpp#L11) + +其实二者有个关键区别:NPU算子的写法基本都属于CUDA术语里的["persistent kernels"](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html),也就是`block_dim`等于硬件的核数,而不是随着输入数据size增长。 + +例如这个[基于PTO的动态shape向量相加](https://github.com/huawei-csl/pto-dsl/blob/b9b0c4abdcb84b84db53f27ffcb4ce8aa1b67316/examples/jit/add_dynamic_multicore/run_add.py#L46-L100):每个core不仅自己算好global memory offset,计算的循环迭代次数也会[随着动态的输入数据size而增加](https://github.com/huawei-csl/pto-dsl/blob/b9b0c4abdcb84b84db53f27ffcb4ce8aa1b67316/examples/jit/add_dynamic_multicore/run_add.py#L83)。这和常规的(非“persistent”)CUDA/Triton kernel 不一样。比如 [Triton vector add](https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html#compute-kernel) 设定 `grid = (ceil_div(n_elements, BLOCK_SIZE),)`,用launch时动态计算的`block_dim`匹配动态input size;而我们大多数的NPU kernel(不管基于PTO、AscendC、CCE 还是其他框架)通常都是 `grid = (num_cores,)`。 + +(在NPU上,大于核数的`block_dim`在简单场景能跑通,但Cube-Vector核间同步容易出bug。而且`block_dim >= 65536`会溢出,远小于CUDA的`maxGridSize`。我们遇过这个bug,通过切回“persistent-kernel”写法[修好了](https://github.com/huawei-csl/pto-kernels/pull/39)) + + +## Software pipelining,自动vs手动 + +NPU的片上缓存为[scratchpad memory](https://en.wikipedia.org/wiki/Scratchpad_memory),而非硬件管理的cache。所以要避免[data hazards](https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Data_hazards)需要开发者或编译器正确地使用[set_flag & wait_flag 接口](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/API/cceintrinsicapi/cceapi_0106.html),本质上是基于 [binary semaphore](https://en.wikipedia.org/wiki/Semaphore_(programming)#Producer%E2%80%93consumer_problem) 
的同步机制。CUDA里最接近的是[`cp.async`+`wait`那一套](https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/async-copies.html)。可以参考这个[基于PTO-ISA手动同步的vector add示例](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/demos/torch_jit/add/add_custom.cpp#L78-L115)。对更复杂的融合算子如[FlashAttention](https://github.com/PTO-ISA/pto-isa/tree/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/kernels/manual/common/flash_atten),思考手动同步、software pipelining 和 prefetching, 对算子开发人员过于烧脑。 + +为了解决这个痛点,[PTO-DSL](https://github.com/huawei-csl/pto-dsl) 提供了自动同步,内部由基于[PTO MLIR dialect](https://github.com/zhangstevenunity/PTOAS/blob/v0.9/docs/PTO_IR_manual.md)的[InsertSync pass](https://github.com/zhangstevenunity/PTOAS/tree/v0.9/lib/PTO/Transforms/InsertSync)实现。对用户而言,算子代码看起来还是“串行的”(在pipelining意义上),写起来更接近Triton/CuTile的手感。 + + +# 第 1 步:功能正确的基础版本 + +根据[NPU硬件架构](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/opdevg/Ascendcopdevg/atlas_ascendc_10_0008.html),要完成matmul需要的数据搬运路径是: +- `GM`(global memory)-> `L1` -> `L0`(左/右操作数对应`L0A`/`L0B`)-> `Cube core` -> `L0C` -> `GM` + +读取到片上的tile 大小(算法参数)受到 L1/L0 SRAM容量(硬件参数)的约束。要查询[硬件参数规格](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/opdevg/Ascendcopdevg/atlas_ascendc_10_0011.html),可以在任意安装了CANN的环境里看文件 `${ASCEND_HOME_PATH}/arm64-linux/data/platform_config/*.ini`: + +```bash +grep -A 9 "AICoreSpec" ${ASCEND_HOME_PATH}/arm64-linux/data/platform_config/Ascend910B2.ini +``` + +输出: + +``` +[AICoreSpec] +... 
+l0_a_size=65536 # 64 KiB +l0_b_size=65536 # 64 KiB +l0_c_size=131072 # 128 KiB +l1_size=524288 # 512 KiB +``` + +考虑经典的[分块矩阵乘法](https://en.wikipedia.org/wiki/Loop_nest_optimization#Example:_matrix_multiplication)。任意shape的`C = A @ B`运算会被分解为tile级别操作:`A_tile = A[i1:i2,k1:k2]`、`B_tile = B[k1:k2,j1:j2]`、`C_tile = C[i1:i2,j1:j2]`,保证每个 tile 能放进 SRAM。结合上面的 SRAM 信息,这里选择: +- `A_tile` 在 `L1` 上为 `[128 x 512]`,占 128 KiB(fp16) +- `B_tile` 在 `L1` 上为 `[256 x 256]`,占 128 KiB(fp16) +- `A_tile` 在 `L0A` 上为 `[128 x 64]`,占 16 KiB(fp16) +- `B_tile` 在 `L0B` 上为 `[64 x 256]`,占 32 KiB(fp16) +- `C_tile` 在 `L0C` 上为 `[128 x 256]`,占 128 KiB(fp32 accumulation) +- Cube unit执行size为`(M, N, K) = (128, 256, 64)` 的 [`TMATMUL`](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TMATMUL.md) 指令,输入为 `L0A` 和 `L0B`,输出为 `L0C`。 + +为啥选这组参数: +- 这是[ATB 库的matmul](https://gitcode.com/cann/ascend-transformer-boost/blob/br_release_cann_8.5.0_20260527/src/kernels/kernels/matmul/pp_matmul_f16_kernel/op_kernel/pp_matmul.cce?init=initTree) 的常用tiling方案之一。也有其他很多可行组合,只要 buffer 能装下。 +- L0上更大的tile有利于Cube unit达到更高的FLOPS。比如128 x 128比32 x 32的FLOPs高好几倍。完整支持的 matmul shape 和 dtype 参见[`Mmad`指令](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/API/ascendcopapi/atlasascendc_api_07_0249.html)。 +- `L1`、`L0A`、`L0B` 都预留了 >=50% 空间没用,留给下一步double-buffering用。 + +[step1_baseline_numpy_sim.py](./step1_baseline_numpy_sim.py) 提供了“NumPy 仿真代码”帮助理解算法逻辑。这里用的算法是最基础的 “split-MN matmul”,每个 core 输出自己的 `C_tile = C[i1:i2,j1:j2]`。(Split-K 和 Stream-K等变种留到以后再说)。算法核心逻辑如下: +- 顶层循环 `for li in range(core_loop):` 来自前文的“persistent kernel”要求。我们不做双层“行列循环”,而是把它们合并成单层 `core_loop = n_loop * m_loop`。这样每次迭代都可以独立分配给不同 core,并独立完成一个 `C_tile`。 +- 然后只需沿内层 K 维做累加: + - 第二层 `for k_idx in range(k_dtile_num)` 对应 “GM - L1 级”迭代:当前 `L1` tile 被 matmul 用完后,再从 `GM` 加载下一个。 + - 第三层 `for phase in range(8):` 对应 “L1 - L0 级”迭代:当前 `L0` tile 被 matmul 用完后,再从 `L1` 加载下一个。 + - 由于 `L1` tile 和 `L0` tile 的尺寸比固定,第三层循环可以**静态展开**。因为 `L0` tile 小于 
`L1` tile,每次 “L1 级”迭代会对应多个 “L0 级”迭代。 + +接着把NumPy翻译成等价的PTO-DSL,见 [step1_baseline.py](./step1_baseline.py) 和 [common_utils.py](./common_utils.py)。代码结构几乎一一对应,只是把NumPy API换成了NPU特有的API: +- `pto.load`([`TLOAD`](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TLOAD.md))做 `GM`->`L1` +- `tile.extract`([`TEXTRACT`](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TEXTRACT.md))做 `L1`->`L0A`、`L1`->`L0B` +- `tile.matmul`/`tile.matmul_acc`([`TMATMUL`](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TMATMUL.md)/[`TMATMUL_ACC`](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TMATMUL_ACC.md))做 `L0` 上的计算 +- `pto.store`([`TSTORE`](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TSTORE.md))做 `L0C`->`GM` +- 静态loop unrolling用 Python 原生 `for i in range()`;run-time动态循环用 `for i in pto.range()`。`if`/`else` 也同理类似。 + +更详细的DSL语法说明见 [附录 A:PTO-DSL 语法说明](#appendix-a-pto-dsl-syntax-note)。 + +这个80行的算子实现可以在NPU跑出正确的数值结果,但性能只有 `torch.matmul` 的 50% 左右。下一节追上性能差距。 + +![image info](./fig/flops_step1_baseline.png) + + +# 第 2 步:Double buffering + +先用 `msprof op simulator` 测试前一版 kernel: + +```bash +msprof op simulator --aic-metrics=PipeUtilization \ + --kernel-name="_Z28matmul_kernel_step1_baselinePDhS_S_iii_mix_aic" \ + --output="msprof_res" --launch-count=5 \ + python ./run_matmul.py --variant step1-baseline +``` + +(更多 profiler 用法见 [附录 B:NPU profiler 使用方法](#appendix-b-using-npu-profiler)) + +可以看到 Cube core 有 50% 时间在空转: + +![image info](./fig/pipeline_N1024_baseline.png) + +做了Double buffering(本质是用空间换时间),可以把计算和数据传输尽量重叠: + +![image info](./fig/pipeline_N1024_doublebuf.png) + +完整代码见 [./step2_doublebuffer.py](./step2_doublebuffer.py)。 + +Profile改进后的算子: + +
+ +```bash +msprof op simulator --aic-metrics=PipeUtilization \ + --kernel-name="_Z26matmul_kernel_ABt_autosyncPDhS_S_iii_mix_aic" \ + --output="msprof_res" --launch-count=5 \ + python ./run_matmul.py --variant step2-doublebuffer +``` + +
+ +唯一的代码改动是在 `L1` 和 `L0` 上给 `A_tile`、`B_tile` 各开 2 份 buffer: + +```python +a_l1 = [pto.alloc_tile(tile_buf_a_l1), pto.alloc_tile(tile_buf_a_l1)] +b_l1 = [pto.alloc_tile(tile_buf_b_l1), pto.alloc_tile(tile_buf_b_l1)] +a_l0 = [pto.alloc_tile(tile_buf_a_l0), pto.alloc_tile(tile_buf_a_l0)] +b_l0 = [pto.alloc_tile(tile_buf_b_l0), pto.alloc_tile(tile_buf_b_l0)] +``` + +然后在迭代之间交替使用 "odd" / "even" 两块 buffer。 + +优化效果显著,对于中小规模的矩阵,FLOPs 基本翻倍: +![image info](./fig/flops_step2_doublebuf.png) + +但矩阵一旦变大(比如 16384x16384),FLOPs 会**突然跌落**。原因是 NPU 的 L2 cache 装不下整块矩阵,开始出现 cache eviction。 + +查看 L2 cache 大小: + +```bash +grep -A 8 "SoCInfo" ${ASCEND_HOME_PATH}/arm64-linux/data/platform_config/Ascend910B2.ini +``` + +输出: + +``` +[SoCInfo] +ai_core_cnt=24 +cube_core_cnt=24 +vector_core_cnt=48 +ai_cpu_cnt=6 +memory_type= +memory_size=68719476736 # 64 GiB +l2_type=0 +l2_size=201326592 # 192 MiB +``` + +8192x8192 矩阵(float16 下 128 MiB)小于 L2;而16384x16384(float16 下 512 MiB)大于 L2,所以后者的性能显著更差。 + +`910B4` 的 HBM 和 L2 都是 `910B2` 的一半(因此更小矩阵就会触发cache eviction): + +```bash +grep -A 8 "SoCInfo" ${ASCEND_HOME_PATH}/arm64-linux/data/platform_config/Ascend910B4.ini +``` + +``` +[SoCInfo] +ai_core_cnt=20 +cube_core_cnt=20 +vector_core_cnt=40 +ai_cpu_cnt=6 +memory_type= +memory_size=34359738368 # 32 GiB +l2_type=0 +l2_size=100663296 # 96 MiB +``` + + +# 第 3 步:通过 "Swizzling" 提升 L2 cache 复用 + +提高多核之间的L2 cache复用,“swizzling”是最常用的技巧,对NPU和GPU都适用。下图借自 [Triton matmul讲解](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations): +Grouped vs row-major ordering (from Triton) + +这张图可以这样理解:假设第一轮迭代有9个核各自在算 `C` 的一个子块(黄色标记区域,0~8 是 core id)。在朴素的 "row-major ordering" 下,完整 B 矩阵(假设大于 L2)要频繁从 global memory 读取;而用 "grouped ordering" 后,global memory traffic大幅下降。 + +[step3_swizzle.py](./step3_swizzle.py) 在 step2 基础上只加了一个 10 行 swizzle 函数 `swizzle_nz`,其余代码完全不动。[step3_swizzle_numpy_sim.py](./step3_swizzle_numpy_sim.py) 直观解释了swizzle对循环下标的影响。这个具体swizzle方案来自[catlass的block 
swizzle](https://gitcode.com/cann/catlass/blob/v1.4.0/include/catlass/gemm/block/block_swizzle.hpp)([讲解文档](https://gitcode.com/cann/catlass/blob/v1.4.0/docs/contents/advanced/swizzle_explanation.md))。 + +(For GPU熟练工: 这个下标重映射类似[DeepGEMM的scheduler](https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/deep_gemm/include/deep_gemm/common/scheduler.cuh),重排每个SM的数据分配和循环顺序) + +只加这 10 行 swizzle,FLOPs 就有明显提升,达到了`torch.matmul`的90%! + +![image info](./fig/flops_step3_swizzle.png) + +为了确认是L2 cache在起作用,用`msprof op`检查cache hit: + +```bash +msprof op \ + --aic-metrics=Occupancy,Roofline,Default,L2Cache,PipeUtilization,MemoryL0 \ + --kernel-name="_Z26matmul_kernel_ABt_autosyncPDhS_S_iii_mix_aic" \ + --output="msprof_res" --launch-count=5 \ + python ./run_matmul.py --variant step3-swizzle +``` + +对4096x4096小矩阵,即使不用swizzled loop order,L2 hit就很高(97.88%): + +![cachehit_N4096](./fig/cachehit_N4096.png) + +对16384x16384大矩阵,由于超过了L2 size,不swizzle的话L2 hit低到了30.9%: + +![cachehit_N16384](./fig/cachehit_N16384.png) + +加了swizzling后,16384x16384场景的L2 hit 提升到93.72%了: + +![cachehit_N16384_swizzle](./fig/cachehit_N16384_swizzle.png) + + +# 第 4 步:(可选)手动 software pipelining + +最后这 10% 的性能差距,可以通过 [./step4_manual_pipelining.py](./step4_manual_pipelining.py) 里的手动排流水压榨出来。 + +![image info](./fig/flops_step4_manual_pipeline.png) + +即便做了手动同步,代码也只是从 ~100 行增长到 ~150 行 Python,仍然比CANN算子库的代码短很多。如何手工排流水超出了本文的讲解范围。我们正在 [推进相关 compile pass](https://github.com/zhangstevenunity/PTOAS/issues/226),争取让编译器自动同步性能持平手排。 + + +# 附录 A:PTO-DSL 语法说明 + +当前的 [PTO-DSL package](https://github.com/huawei-csl/pto-dsl/tree/b9b0c4abdcb84b84db53f27ffcb4ce8aa1b67316/ptodsl) 只是在 PTO dialect 的 [MLIR Python bindings](https://mlir.llvm.org/docs/Bindings/Python/)上做了很薄的封装。整个DSL包只有 **约1000行Python**(可以用 `cd ptodsl && find . 
-name "*.py" | xargs wc -l` 自行确认) + +为了在开发阶段维持一个简单好改的框架,我们目前**不**做Python AST parsing / AST rewriting。因此,所有 Python 原生语法(包括`if`/`for` 控制流、Python class、iterator 等)都按普通Python代码执行。这点和其他Python DSL的做法不太相同:有的是纯 AST 路线(如 Triton、CuTile),有的是 AST+tracing 混合路线(如 Tilelang、CuteDSL),它们 *可能会,也可能不会* 把原生 `if`/`range` rewrite成特殊 IR builder(可参考 [CuteDSL 的复杂规则](https://github.com/Dao-AILab/quack/blob/v0.3.2/docs/dsl_control_flow.rst))。当前 PTO-DSL frontend 是纯 Python tracing,更接近 JAX 的思路。 + +**用户只要记住:** run-time动态控制流全在 `pto` 命名空间里(例如 `pto.range`,会在 IR 中生成 [MLIR structured control flow](https://mlir.llvm.org/docs/Dialects/SCFDialect/));而 Python 原生控制流是在 build-time 就求值完成的。 + +常见场景: + +- **Python `for ... in range(...)`** + - 在生成 IR 前执行(build-time) + - 常用于编译期 metaprogramming / unrolling +- **`for ... in pto.range(...)`** + - 生成 MLIR `scf.for` loop + - 在 kernel run-time 动态执行 +- **Python `if condition:`** + - condition 在 build-time 由 Python 求值 + - 分支在生成 IR 前就被选定 +- **`with pto.if_context(cond):` / `pto.cond(...)`** + - 在 IR 中生成 runtime `scf.if` + - condition 在 kernel 运行时求值 + +**示例 1:`pto.range`(IR 里的 runtime loop)** + +来自 `step1_baseline.py`: + +```python +for li in pto.range(bid, core_loop, num_blocks): + ... +``` + +这**不是**普通Python循环。在 PTO-DSL 里,`pto.range` 是一个 IR-builder primitive(见 `control_flow.py`),会创建 `scf.ForOp` 并返回 induction-variable。 + +实际效果:会以 loop 形式保留在 IR 里(不会被 Python 展开) + +**示例 2:Python `range`(build-time unrolling)** + +来自 `step1_baseline.py`: + +```python +for phase in range(8): + ... 
+``` + +这个 loop 在构建 IR 时由 Python 执行,所以通常会在 IR 中生成 8 份重复代码区域。 + +类比C++编程: +- 概念上接近 compile-time codegen / metaprogramming +- 当 loop bound 是小常量时非常实用 + +**示例 3:Python `if` vs `pto.if_context`** + +来自 `step1_baseline.py`: + +```python +if phase == 0: + with pto.if_context(is_first_k_tile, has_else=True) as branch: + tile.matmul(a_l0, b_l0, c_l0) + with branch.else_context(): + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) +else: + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) +``` + +理解方式: +- `if phase == 0` 是 **普通Python** 分支(build-time) +- `pto.if_context(is_first_k_tile, ...)` 在 IR 中生成 **runtime** 分支 + + +# 附录 B:NPU profiler 使用方法 + +`--kernel-name=` 参数里的 kernel 名字怎么找:先不带 `--kernel-name=` 跑一次 `msprof op`,输出里会直接打印 kernel 名。 + +完整官方文档见 [msProf](https://www.hiascend.com/document/detail/zh/canncommercial/850/devaids/optool/atlasopdev_16_0082.html)。 + +查看 profiler trace 的 UI 工具下载: + +```bash +# Windows x86 +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindStudio/MindStudio%208.3.0/MindStudio-Insight_8.3.0_win.exe + +# Mac arm and x86 +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindStudio/MindStudio%208.3.0/MindStudio-Insight_8.3.0_darwin-aarch64.dmg +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindStudio/MindStudio%208.3.0/MindStudio-Insight_8.3.0_darwin-x86_64.dmg + +# Linux arm and x86 +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindStudio/MindStudio%208.3.0/MindStudio-Insight_8.3.0_linux-aarch64.zip +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindStudio/MindStudio%208.3.0/MindStudio-Insight_8.3.0_linux-x86_64.zip +``` + +以上链接来自 [CANN 下载页](https://www.hiascend.com/developer/download/community/result?module=sto)。 diff --git a/examples/aot/matmul_optimization_guide/matmul_optim_guide.md b/examples/aot/matmul_optimization_guide/matmul_optim_guide.md new file mode 100644 index 00000000..f8311a67 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/matmul_optim_guide.md @@ -0,0 +1,341 @@ +# NPU Matmul kernel from 
scratch -- reaching CANN library performance using 100 lines of Python
(step-by-step optimization guide using PTO-ISA) + +- Date: 2026/03/12 +- Author: Jiawei Zhuang +- Contributor: Filip Skogh, Mirko De Vita, Hyun Min Chang + +# Outline + +- [Motivation](#motivation) +- [Step 0: NPU programming crash course for CUDA/Triton programmers](#step-0-npu-programming-crash-course-for-cudatriton-programmers) + - [Typical kernel launch syntax](#typical-kernel-launch-syntax) + - [Auto vs manual software pipelining](#auto-vs-manual-software-pipelining) +- [Step 1: Functionally-correct naive version](#step-1-functionally-correct-naive-version) +- [Step 2: Double buffering](#step-2-double-buffering) +- [Step 3: "Swizzling" for L2 cache reuse](#step-3-swizzling-for-l2-cache-reuse) +- [Step 4: (optional) Manual software pipelining](#step-4-optional-manual-software-pipelining) +- [Appendix A: PTO-DSL syntax note](#appendix-a-pto-dsl-syntax-note) +- [Appendix B: Using NPU profiler](#appendix-b-using-npu-profiler) + +**To reproduce all results shown in this guide**, see commands in [README.md](./README.md) + +# Motivation + +This guide is the NPU version of "step-by-step matmul optimization", a popular article style for NVIDIA GPUs (e.g. [for A100](https://siboehm.com/articles/22/CUDA-MMM) and [for H100](https://cudaforfun.substack.com/p/outperforming-cublas-on-h100-a-worklog)), but never written for our NPUs before. + +We show step-by-step how to match the performance of a carefully optimized official library, using **only ~100 lines of Python DSL**. The target to compare is `torch.matmul`, which invokes [aclnnMatmul](https://www.hiascend.com/document/detail/zh/canncommercial/850/API/aolapi/context/ops-nn/aclnnMatmul.md) (our "cuBLAS" for NPU), internally implemented by [many thousands of lines of AscendC](https://gitcode.com/cann/ops-nn/tree/v8.5.0/matmul/mat_mul_v3/op_kernel). + +I intentionally keep the code samples **minimal, hackable, from-scratch, and without magical templates and wrappers**, to highlight the few key optimizations. 
There are more comprehensive "Matmul optimizations lists" [in catlass](https://gitcode.com/cann/catlass/blob/v1.4.0/docs/contents/advanced/matmul_template_summary.md) or [in AscendC](https://www.hiascend.com/document/detail/zh/canncommercial/850/opdevg/Ascendcopdevg/atlas_ascendc_best_practices_10_10006.html), which hide optimization tricks behind templates and wrappers. They are more suited for later, more advanced study. + +# Step 0: NPU programming crash course for CUDA/Triton programmers + +(jump to the next section if you have programmed NPU kernels before) + +## Typical kernel launch syntax + +The [SPMD](https://en.wikipedia.org/wiki/Single_program,_multiple_data)-style kernels on NPU look **deceptively similar** to CUDA/Triton kernel syntax: +- The `block_idx` and `block_num` built-in variables assist offset calculations for each core -- [example here](https://github.com/huawei-csl/pto-dsl/blob/b9b0c4abdcb84b84db53f27ffcb4ce8aa1b67316/examples/jit/add_dynamic_multicore/run_add.py) +- The CUDA-style `kernel_name<<<block_dim, nullptr, stream>>>(args)` kernel launch -- [example here](https://github.com/huawei-csl/pto-dsl/blob/b9b0c4abdcb84b84db53f27ffcb4ce8aa1b67316/examples/aot/elementwise/add_dynamic_multicore/caller.cpp#L11) + +However, there is an important difference: all NPU kernels are ["persistent kernels"](https://triton-lang.org/main/getting-started/tutorials/09-persistent-matmul.html) in CUDA terminology, i.e. the `block_dim` is forced to be the number of cores instead of growing with the input data size. + +Check this [PTO dynamic-shape vector-add example](https://github.com/huawei-csl/pto-dsl/blob/b9b0c4abdcb84b84db53f27ffcb4ce8aa1b67316/examples/jit/add_dynamic_multicore/run_add.py#L46-L100) -- each core calculates its own global memory offsets, and the required number of iterations [depends dynamically on the input data size](https://github.com/huawei-csl/pto-dsl/blob/b9b0c4abdcb84b84db53f27ffcb4ce8aa1b67316/examples/jit/add_dynamic_multicore/run_add.py#L83). 
This is **unlike** conventional ("non-persistent") CUDA/Triton kernels, where a data-dependent `block_dim` handles the dynamic input size. For example, unlike [Triton vector add](https://triton-lang.org/main/getting-started/tutorials/01-vector-add.html#compute-kernel) that sets `grid = (ceil_div(n_elements, BLOCK_SIZE),)`, most of our NPU kernels (no matter whether they are written in PTO, AscendC, CCE, or other frameworks) always have `grid = (num_cores,)`. + +(A data-dependent large `block_dim` *might* work for simple cases on NPU, but it can often hit bugs during Cube-Vector synchronization, and can also overflow if `block_dim >= 65536` -- a bug [that we fixed](https://github.com/huawei-csl/pto-kernels/pull/39) by switching to persistent-kernel style.) + +## Auto vs manual software pipelining + +Our NPU uses on-chip [scratchpad memory](https://en.wikipedia.org/wiki/Scratchpad_memory) instead of hardware-managed cache, so [data hazards](https://en.wikipedia.org/wiki/Hazard_(computer_architecture)#Data_hazards) must be avoided by the programmer or software using [set_flag & wait_flag APIs](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/API/cceintrinsicapi/cceapi_0106.html), essentially a [binary-semaphore](https://en.wikipedia.org/wiki/Semaphore_(programming)#Producer%E2%80%93consumer_problem) synchronization mechanism. The closest analogy in CUDA is [all the `cp.async` stuff](https://docs.nvidia.com/cuda/cuda-programming-guide/04-special-topics/async-copies.html) that needs manual waits. See this [manually synchronized vector-add example](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/demos/torch_jit/add/add_custom.cpp#L78-L115). For complex fused kernels like [FlashAttention](https://github.com/PTO-ISA/pto-isa/tree/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/kernels/manual/common/flash_atten), it can be hard to reason about manual synchronization, software pipelining, and prefetching. 
+ +To solve this headache, [PTO-DSL](https://github.com/huawei-csl/pto-dsl) offers automatic synchronization, internally achieved by the [InsertSync](https://github.com/zhangstevenunity/PTOAS/tree/v0.9/lib/PTO/Transforms/InsertSync) compile pass based on the [PTO MLIR dialect](https://github.com/zhangstevenunity/PTOAS/blob/v0.9/docs/PTO_IR_manual.md). The kernel code still looks "sequential" (in the pipelining dimension), similar to writing Triton or CuTile code. + +# Step 1: Functionally-correct naive version + +According to our [NPU hardware architecture](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/opdevg/Ascendcopdevg/atlas_ascendc_10_0008.html), a matmul operation requires this movement across the memory hierarchy: +- `GM` (global memory) -> `L1` -> `L0` (`L0A` or `L0B` for left or right operands) -> `Cube core` -> `L0C` -> `GM` + +The on-chip tile size (an algorithm parameter) is bounded by the L1/L0 SRAM size constraint (a hardware parameter). The [NPU hardware spec](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/opdevg/Ascendcopdevg/atlas_ascendc_10_0011.html) can be found in files `${ASCEND_HOME_PATH}/arm64-linux/data/platform_config/*.ini` in any CANN-installed environment: + +```bash +grep -A 9 "AICoreSpec" ${ASCEND_HOME_PATH}/arm64-linux/data/platform_config/Ascend910B2.ini +``` + +gives: + +``` +[AICoreSpec] +... +l0_a_size=65536 # 64 KiB +l0_b_size=65536 # 64 KiB +l0_c_size=131072 # 128 KiB +l1_size=524288 # 512 KiB +``` + +Consider the classic [tiled matrix multiplication](https://en.wikipedia.org/wiki/Loop_nest_optimization#Example:_matrix_multiplication) -- a general-shape matmul `C = A @ B` is implemented by tile-level operations over `A_tile = A[i1:i2,k1:k2]`, `B_tile = B[k1:k2,j1:j2]`, `C_tile = C[i1:i2,j1:j2]`, so that each tile fits into SRAM. 
Given the above SRAM info, we choose the tile sizes as: +- `[128 x 512]` for `A_tile` on `L1`, taking 128 KiB (fp16) +- `[256 x 256]` for `B_tile` on `L1`, taking 128 KiB (fp16) +- `[128 x 64]` for `A_tile` on `L0A`, taking 16 KiB (fp16) +- `[64 x 256]` for `B_tile` on `L0B`, taking 32 KiB (fp16) +- `[128 x 256]` for `C_tile` on `L0C`, taking 128 KiB (fp32 accumulation) +- The Cube unit performs the [`TMATMUL`](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TMATMUL.md) instruction of size `(M, N, K) = (128, 256, 64)`, taking `L0A` and `L0B` as input and `L0C` as output. + +Why choose these tile sizes: +- This is a common tiling choice [in the ATB library's matmul](https://gitcode.com/cann/ascend-transformer-boost/blob/br_release_cann_8.5.0_20260527/src/kernels/kernels/matmul/pp_matmul_f16_kernel/op_kernel/pp_matmul.cce?init=initTree), but many other choices also work as long as they fit into the buffers. +- The Cube unit prefers larger tile sizes for higher FLOPs utilization. For example, 128 x 128 typically achieves higher FLOPs than 32 x 32. For the full set of supported matmul shapes and dtypes, see the [`Mmad` instruction](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/850/API/ascendcopapi/atlasascendc_api_07_0249.html). +- We still have >=50% space left for `L1`, `L0A`, `L0B`. They are reserved for double-buffering later. + +See [step1_baseline_numpy_sim.py](./step1_baseline_numpy_sim.py) for the full "NumPy emulation code" that explains the algorithm logic. It's the most basic "split-MN matmul", where each core outputs its own `C_tile = C[i1:i2,j1:j2]`. We leave Split-K and Stream-K matmuls for future posts. The key code components are: +- The top-level loop `for li in range(core_loop):` comes from our "persistent kernel" requirement explained in [Typical kernel launch syntax](#typical-kernel-launch-syntax). 
Instead of having two-level "row and column loops", we bundle them together into a single-level `core_loop = n_loop * m_loop`, where each iteration can be independently assigned to a different core and completes its own `C_tile` calculation. +- Then we only need to accumulate over the inner K-dimension: + - The second-level loop `for k_idx in range(k_dtile_num)` is for "GM - L1 level" iterations. Once the current tile on `L1` is fully consumed by matmul and no longer needed, we load the next tile from `GM`. + - The third-level loop `for phase in range(8):` is for "L1 - L0 level" iterations. Once the current tile on `L0` is fully consumed by matmul and no longer needed, we load the next tile from `L1`. + - Notice that the third-level loop can be **statically unrolled** because we have a fixed ratio between `L1` and `L0` tile sizes. Because `L0` tiles are smaller than `L1` tiles, more than one "L0-level iteration" is required to match each "L1-level iteration". + +Then, we translate this NumPy emulation code into equivalent PTO-DSL code in [step1_baseline.py](./step1_baseline.py) and [common_utils.py](./common_utils.py). 
The PTO code logic largely follows the NumPy emulation, while using NPU-specific data movement and compute APIs: +- Use `pto.load` ([TLOAD](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TLOAD.md)) for `GM`->`L1` load +- Use `tile.extract` ([TEXTRACT](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TEXTRACT.md)) for `L1`->`L0A`, `L1`->`L0B` loads +- Use `tile.matmul`/`tile.matmul_acc` ([TMATMUL](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TMATMUL.md)/[TMATMUL_ACC](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TMATMUL_ACC.md)) for compute on `L0` +- Use `pto.store` ([TSTORE](https://github.com/PTO-ISA/pto-isa/blob/5de2d24d53e8cf39dec5fc11f997d1e74fa7190c/docs/isa/TSTORE.md)) for `L0C`->`GM` store +- Use native Python `for i in range()` for statically unrolled loop, and `for i in pto.range()` for run-time dynamic loop. Similarly for `if`/`else` branching. + +More DSL-specific syntax details are explained in [Appendix A: PTO-DSL syntax note](#appendix-a-pto-dsl-syntax-note). + +This simple 80-line PTO kernel produces numerically correct results on NPU, but the performance is only 50% of the `torch.matmul` reference. We will close the gap in the next section. 
+ +![image info](./fig/flops_step1_baseline.png) + +# Step 2: Double buffering + +Profiling our previous kernel with `msprof op simulator`: + +```bash +msprof op simulator --aic-metrics=PipeUtilization \ + --kernel-name="_Z28matmul_kernel_step1_baselinePDhS_S_iii_mix_aic" \ + --output="msprof_res" --launch-count=5 \ + python ./run_matmul.py --variant step1-baseline +``` + +(see [Appendix B: Using NPU profiler](#appendix-b-using-npu-profiler) for more profiler usage details) + +We see that the Cube core is idle for 50% of the time: + +![image info](./fig/pipeline_N1024_baseline.png) + +Double buffering overlaps compute and data transfer: + +![image info](./fig/pipeline_N1024_doublebuf.png) + +See full code in [./step2_doublebuffer.py](./step2_doublebuffer.py). + +Profile with: + +
+ +```bash +msprof op simulator --aic-metrics=PipeUtilization \ + --kernel-name="_Z26matmul_kernel_ABt_autosyncPDhS_S_iii_mix_aic" \ + --output="msprof_res" --launch-count=5 \ + python ./run_matmul.py --variant step2-doublebuffer +``` + +
+ +The only difference is that we allocate 2x local buffers for `A_tile` and `B_tile` on both `L1` and `L0`: + +```python +a_l1 = [pto.alloc_tile(tile_buf_a_l1), pto.alloc_tile(tile_buf_a_l1)] +b_l1 = [pto.alloc_tile(tile_buf_b_l1), pto.alloc_tile(tile_buf_b_l1)] +a_l0 = [pto.alloc_tile(tile_buf_a_l0), pto.alloc_tile(tile_buf_a_l0)] +b_l0 = [pto.alloc_tile(tile_buf_b_l0), pto.alloc_tile(tile_buf_b_l0)] +``` + +and alternate between the "odd" and "even" buffers across iterations. + +Now the FLOPs are doubled for not-so-large matrices: +![image info](./fig/flops_step2_doublebuf.png) + +For large-enough matrices such as 16384x16384, the FLOPs **suddenly drop** because the NPU L2 cache is not large enough to hold the entire matrix, and the data gets evicted from cache. + +We can check the L2 cache size with: + +```bash +grep -A 8 "SoCInfo" ${ASCEND_HOME_PATH}/arm64-linux/data/platform_config/Ascend910B2.ini +``` + +gives: + +``` +[SoCInfo] +ai_core_cnt=24 +cube_core_cnt=24 +vector_core_cnt=48 +ai_cpu_cnt=6 +memory_type= +memory_size=68719476736 # 64 GiB +l2_type=0 +l2_size=201326592 # 192 MiB +``` + +An 8192x8192 matrix (128 MiB in float16) is smaller than L2, but a 16384x16384 matrix (512 MiB in float16) is larger than L2, so we see worse performance. + +For `910B4`, both HBM size and L2 cache size are smaller by half (thus the cache eviction effect happens for smaller matrices): + +```bash +grep -A 8 "SoCInfo" ${ASCEND_HOME_PATH}/arm64-linux/data/platform_config/Ascend910B4.ini +``` + +``` +[SoCInfo] +ai_core_cnt=20 +cube_core_cnt=20 +vector_core_cnt=40 +ai_cpu_cnt=6 +memory_type= +memory_size=34359738368 # 32 GiB +l2_type=0 +l2_size=100663296 # 96 MiB +``` + +# Step 3: "Swizzling" for L2 cache reuse + +Swizzling improves L2 cache reuse across multiple cores. 
We borrow this figure [from Triton matmul](https://triton-lang.org/main/getting-started/tutorials/03-matrix-multiplication.html#l2-cache-optimizations): +Grouped vs row-major ordering (from Triton) + +To read this figure, assume 9 cores computing a subset of the `C` matrix in the first iteration (the yellow area, each number 0 ~ 8 marks the core id). In the naive "row-major ordering", the full matrix B (assume larger than L2 cache!) needs to be loaded from global memory; while in the "grouped ordering", the data traffic w.r.t. global memory is much less. + +[step3_swizzle.py](./step3_swizzle.py) incorporates a 10-line swizzling function `swizzle_nz`, while keeping the rest of the code the same as step2. [step3_swizzle_numpy_sim.py](./step3_swizzle_numpy_sim.py) explains the swizzle scheme intuitively. The swizzle algorithm is one of the algorithms [from catlass](https://gitcode.com/cann/catlass/blob/v1.4.0/include/catlass/gemm/block/block_swizzle.hpp), which also [has a nice explanation](https://gitcode.com/cann/catlass/blob/v1.4.0/docs/contents/advanced/swizzle_explanation.md). + +(for GPU experts -- such index remapping is analogous to [the "scheduler" in DeepGEMM](https://github.com/deepseek-ai/DeepGEMM/blob/v2.1.1/deep_gemm/include/deep_gemm/common/scheduler.cuh), which alters data assignment and loop order for each SM) + +With just this 10-line swizzle function, the FLOPs are much improved, reaching ~90% of `torch.matmul`! 
+ +![image info](./fig/flops_step3_swizzle.png) + +To confirm the L2 cache effect, profile cache hit with `msprof op`: + +```bash +msprof op \ + --aic-metrics=Occupancy,Roofline,Default,L2Cache,PipeUtilization,MemoryL0 \ + --kernel-name="_Z26matmul_kernel_ABt_autosyncPDhS_S_iii_mix_aic" \ + --output="msprof_res" --launch-count=5 \ + python ./run_matmul.py --variant step3-swizzle +``` + +For a small 4096x4096 matrix, L2 cache hit is high (97.88%) even without a swizzled loop order: + +cachehit_N4096 + +For a larger 16384x16384 matrix that exceeds L2, L2 cache hit is low (30.9%) without swizzling: + +cachehit_N16384 + +With swizzling, the 16384x16384 case now gets a high (93.72%) L2 hit rate: + +cachehit_N16384_swizzle + + +# Step 4: (optional) Manual software pipelining + +The last 10% performance gap can be squeezed out by manual software pipelining in [./step4_manual_pipelining.py](./step4_manual_pipelining.py). + +![image info](./fig/flops_step4_manual_pipeline.png) + +Even with manual sync, the code only increases from ~100 lines to ~150 lines of Python, still much shorter than library code. How to manually arrange the sync flags is out of scope for this guide. We are [investigating the compile pass](https://github.com/zhangstevenunity/PTOAS/issues/226) so that compiler-inserted sync can eventually reach manual performance. + +# Appendix A: PTO-DSL syntax note + +The current [PTO-DSL package](https://github.com/huawei-csl/pto-dsl/tree/b9b0c4abdcb84b84db53f27ffcb4ce8aa1b67316/ptodsl) is just a very thin wrapper over the [MLIR Python bindings](https://mlir.llvm.org/docs/Bindings/Python/) of PTO dialect. The entire package has **only ~1000 lines of Python** (you can check by `cd ptodsl && find . -name "*.py" | xargs wc -l`). + +To keep the framework simple during rapid development, we are NOT using Python AST parsing or AST rewriting. Thus, all Python-native constructs (`if`/`for` control flows, Python classes, iterators, etc.) execute like normal Python code. 
This is unlike other pure-AST (the case for Triton & CuTile) or hybrid AST+tracing (the case for Tilelang & CuteDSL) frontends that *might or might not* rewrite native `if`/`range` as special IR builders (e.g. see the [complex rules for CuteDSL](https://github.com/Dao-AILab/quack/blob/v0.3.2/docs/dsl_control_flow.rst)). The current PTO-DSL frontend is pure Python tracing, most like JAX's approach. + +**Users should keep in mind:** run-time dynamic control flows are only available in the `pto` namespace such as `pto.range` (which creates [MLIR structured control flow](https://mlir.llvm.org/docs/Dialects/SCFDialect/) in the IR module), while Python native control flows are evaluated at build time. + +Common cases: + +- **Python `for ... in range(...)`** + - runs before generating the IR (build-time) + - usually acts like compile-time metaprogramming/unrolling +- **`for ... in pto.range(...)`** + - emits an MLIR `scf.for` loop + - executes dynamically at kernel run-time +- **Python `if condition:`** + - condition evaluated at build-time by Python + - branch is selected before generating IR +- **`with pto.if_context(cond):` / `pto.cond(...)`** + - emits runtime `scf.if` + - condition is evaluated when kernel runs + +**Example 1: `pto.range` (runtime loop in IR)** + +From `step1_baseline.py`: + +```python +for li in pto.range(bid, core_loop, num_blocks): + ... +``` + +This is **not** Python iteration over integers. In PTO-DSL, `pto.range` is an IR-builder primitive (see `control_flow.py`) that constructs `scf.ForOp` and yields an induction-variable value. + +Practical effect: +- loop trip count depends on runtime values like `bid`, `core_loop`, `num_blocks` +- loop stays as a loop in generated IR (not unrolled by Python) + +**Example 2: Python `range` (build-time unrolling)** + +From `step1_baseline.py`: + +```python +for phase in range(8): + ... +``` + +This loop is executed by Python while building IR, so it typically creates 8 repeated code regions in IR. 
+ +For readers with C++ background: +- this is conceptually similar to compile-time code generation / metaprogramming +- useful when loop bounds are small constants + +**Example 3: Python `if` vs `pto.if_context`** + +From `step1_baseline.py`: + +```python +if phase == 0: + with pto.if_context(is_first_k_tile, has_else=True) as branch: + tile.matmul(a_l0, b_l0, c_l0) + with branch.else_context(): + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) +else: + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) +``` + +How to read this correctly: +- `if phase == 0` is a **Python** branch (build-time), because `phase` is a Python integer from `range(8)`. +- `pto.if_context(is_first_k_tile, ...)` emits a **runtime** branch in IR, because `is_first_k_tile` is a kernel scalar value. + +In plain words: +- first, Python decides which code shape to generate for each unrolled `phase` +- inside that shape, PTO-DSL inserts dynamic control flow for runtime conditions + +# Appendix B: Using NPU profiler + +How to find the kernel name for the `--kernel-name=` argument: first run `msprof op` without `--kernel-name=`, then it will print the kernel name. + +See the [full official doc for msProf](https://www.hiascend.com/document/detail/zh/canncommercial/850/devaids/optool/atlasopdev_16_0082.html). 
+ +For the UI to inspect profiler traces, download with: + +```bash +# Windows x86 +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindStudio/MindStudio%208.3.0/MindStudio-Insight_8.3.0_win.exe + +# Mac arm and x86 +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindStudio/MindStudio%208.3.0/MindStudio-Insight_8.3.0_darwin-aarch64.dmg +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindStudio/MindStudio%208.3.0/MindStudio-Insight_8.3.0_darwin-x86_64.dmg + +# Linux arm and x86 +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindStudio/MindStudio%208.3.0/MindStudio-Insight_8.3.0_linux-aarch64.zip +wget https://ascend-repo.obs.cn-east-2.myhuaweicloud.com/MindStudio/MindStudio%208.3.0/MindStudio-Insight_8.3.0_linux-x86_64.zip +``` + +Those links are obtained from [this CANN download page](https://www.hiascend.com/developer/download/community/result?module=sto). diff --git a/examples/aot/matmul_optimization_guide/run_matmul.py b/examples/aot/matmul_optimization_guide/run_matmul.py new file mode 100644 index 00000000..99f4669b --- /dev/null +++ b/examples/aot/matmul_optimization_guide/run_matmul.py @@ -0,0 +1,214 @@ +import ctypes +import os +import argparse +from dataclasses import dataclass + +import torch +import torch.nn.functional as F +import torch_npu + +from ptodsl.test_util import get_test_device + + +BLOCK_DIM_LIST = [1, 20, 24] +M_LIST = [128 * i for i in range(1, 37, 4)] # 128, ..., 4224 +SHAPES_NK = [ + (4096, 4096), + (8192, 8192), + (16384, 16384), +] +MAX_ABSDIFF_THRESHOLD = 0.5 +MEAN_ABSDIFF_THRESHOLD = 1e-4 + + +@dataclass +class CaseResult: + m: int + n: int + k: int + block_dim: int + max_absdiff: float + mean_absdiff: float + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def load_lib(lib_path): + lib = ctypes.CDLL(os.path.abspath(lib_path)) + lib.call_kernel.argtypes = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + 
ctypes.c_int, + ctypes.c_int, + ctypes.c_int, + ] + lib.call_kernel.restype = None + + def matmul_abt( + a, + b, + *, + block_dim=24, + stream_ptr=None, + ): + if a.ndim != 2 or b.ndim != 2: + raise ValueError("matmul_abt expects 2D tensors: a[M,K], b[N,K]") + if a.shape[1] != b.shape[1]: + raise ValueError( + f"K mismatch: a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)}" + ) + if a.dtype != torch.float16 or b.dtype != torch.float16: + raise ValueError("matmul_abt currently supports float16 inputs only") + + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ + + m = int(a.shape[0]) + k = int(a.shape[1]) + n = int(b.shape[0]) + c = torch.empty((m, n), device=a.device, dtype=a.dtype) + + lib.call_kernel( + block_dim, + stream_ptr, + torch_to_ctypes(a), + torch_to_ctypes(b), + torch_to_ctypes(c), + m, + n, + k, + ) + return c + + return matmul_abt + + +def run_case(matmul_abt, a, b, c_ref, *, block_dim): + c = matmul_abt(a, b, block_dim=block_dim) + torch.npu.synchronize() + result = CaseResult( + m=int(a.shape[0]), + n=int(b.shape[0]), + k=int(a.shape[1]), + block_dim=block_dim, + max_absdiff=float((c - c_ref).abs().max().item()), + mean_absdiff=float((c - c_ref).abs().mean().item()), + ) + del c + torch.npu.empty_cache() + return result + + +def test_matmul(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--variant", + choices=[ + "step1-baseline", + "step2-doublebuffer", + "step3-swizzle", + "step4-manual-pipelining", + "all", + ], + default="all", + help="Which kernel variant to run.", + ) + args = parser.parse_args() + + device = get_test_device() + torch.npu.set_device(device) + + variants = { + "step1-baseline": "./build_artifacts/step1_baseline_kernel.so", + "step2-doublebuffer": "./build_artifacts/step2_doublebuffer_kernel.so", + "step3-swizzle": "./build_artifacts/step3_swizzle_kernel.so", + "step4-manual-pipelining": "./build_artifacts/step4_manual_pipelining_kernel.so", + } + if args.variant == "all": + 
selected = [ + ("step1-baseline", variants["step1-baseline"]), + ("step2-doublebuffer", variants["step2-doublebuffer"]), + ("step3-swizzle", variants["step3-swizzle"]), + ("step4-manual-pipelining", variants["step4-manual-pipelining"]), + ] + else: + selected = [(args.variant, variants[args.variant])] + + torch.manual_seed(0) + for variant_name, lib_path in selected: + print(f"\n=== Running variant: {variant_name} ({lib_path}) ===") + matmul_abt = load_lib(lib_path) + + checked_cases = 0 + global_worst = None + for m in M_LIST: + for n, k in SHAPES_NK: + a = torch.randn(m, k, dtype=torch.float16, device=device) + b = torch.randn(n, k, dtype=torch.float16, device=device) + c_ref = F.linear(a, b) + torch.npu.synchronize() + + shape_worst = None + for block_dim in BLOCK_DIM_LIST: + result = run_case(matmul_abt, a, b, c_ref, block_dim=block_dim) + checked_cases += 1 + + if ( + shape_worst is None + or result.max_absdiff > shape_worst.max_absdiff + or ( + result.max_absdiff == shape_worst.max_absdiff + and result.mean_absdiff > shape_worst.mean_absdiff + ) + ): + shape_worst = result + + if ( + global_worst is None + or result.max_absdiff > global_worst.max_absdiff + or ( + result.max_absdiff == global_worst.max_absdiff + and result.mean_absdiff > global_worst.mean_absdiff + ) + ): + global_worst = result + + del a, b, c_ref + torch.npu.empty_cache() + + print( + f"(m, n, k)=({m}, {n}, {k}) " + f"worst(block_dim)={shape_worst.block_dim} " + f"max_absdiff={shape_worst.max_absdiff:.6f} " + f"mean_absdiff={shape_worst.mean_absdiff:.6f}" + ) + + print(f"[{variant_name}] checked_cases={checked_cases}") + print( + f"[{variant_name}] global_worst " + f"max_absdiff={global_worst.max_absdiff:.6f} " + f"mean_absdiff={global_worst.mean_absdiff:.6f} " + f"at (m, n, k, block_dim)=" + f"({global_worst.m}, {global_worst.n}, {global_worst.k}, " + f"{global_worst.block_dim})" + ) + + if global_worst.max_absdiff > MAX_ABSDIFF_THRESHOLD: + raise AssertionError( + f"[{variant_name}] 
max_absdiff {global_worst.max_absdiff:.6f} exceeds " + f"threshold {MAX_ABSDIFF_THRESHOLD:.6f}" + ) + if global_worst.mean_absdiff > MEAN_ABSDIFF_THRESHOLD: + raise AssertionError( + f"[{variant_name}] mean_absdiff {global_worst.mean_absdiff:.6f} exceeds " + f"threshold {MEAN_ABSDIFF_THRESHOLD:.6f}" + ) + + +if __name__ == "__main__": + test_matmul() diff --git a/examples/aot/matmul_optimization_guide/step1_baseline.py b/examples/aot/matmul_optimization_guide/step1_baseline.py new file mode 100644 index 00000000..e192dd77 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/step1_baseline.py @@ -0,0 +1,138 @@ +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +from common_utils import ( + K_DTILE, + K_QTILE, + K_TILE, + M_TILE, + N_FULL, + build_meta_data, + const, +) + + +def build(): + meta_data = build_meta_data() + + @to_ir_module(meta_data=meta_data) + def matmul_kernel_step1_baseline( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + m_i32: "i32", + n_i32: "i32", + k_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + c128 = const(M_TILE) + c256 = const(N_FULL) + c512 = const(K_DTILE) + + m_total = s.index_cast(m_i32) + n_total = s.index_cast(n_i32) + k_total = s.index_cast(k_i32) + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + + n_loop = (n_total + c256 - c1) // c256 + m_loop = m_total // c128 + core_loop = n_loop * m_loop + k_dtile_num = k_total // c512 + + tv_a = pto.as_tensor( + tv_2d, ptr=a_ptr, shape=[m_total, k_total], strides=[k_total, c1] + ) + tv_b = pto.as_tensor( + tv_2d, + ptr=b_ptr, + shape=[k_total, n_total], + strides=[c1, k_total], + layout="DN", + ) + tv_c = pto.as_tensor( + tv_2d, ptr=c_ptr, shape=[m_total, n_total], strides=[n_total, c1] + ) + + a_l1 = pto.alloc_tile(tile_buf_a_l1) + b_l1 = pto.alloc_tile(tile_buf_b_l1) + a_l0 = pto.alloc_tile(tile_buf_a_l0) + b_l0 = pto.alloc_tile(tile_buf_b_l0) 
+ c_l0 = pto.alloc_tile(tile_buf_c) + + for li in pto.range(bid, core_loop, num_blocks): + m_idx = li // n_loop + n_idx = li % n_loop + m_offset = m_idx * c128 + n_offset = n_idx * c256 + c_kt = const(K_TILE) + c_kd = const(K_DTILE) + c_nt = const(N_FULL) + + sv_a0 = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, c0], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a0, a_l1) + + for k_idx in pto.range(c0, k_dtile_num, c1): + k_offset = k_idx * c_kd + is_first_k_tile = k_idx == c0 + + for phase in range(8): + if phase % 4 == 0: + b_half = phase // 4 + h_off = const(b_half * K_TILE) + sv_b = pto.slice_view( + tile_view_b, + source=tv_b, + offsets=[k_offset + h_off, n_offset], + sizes=[c_kt, c_nt], + ) + pto.load(sv_b, b_l1) + + a_col = const(phase * K_QTILE) + b_row = const((phase % 4) * K_QTILE) + tile.extract(a_l1, c0, a_col, a_l0) + tile.extract(b_l1, b_row, c0, b_l0) + + if phase == 0: + with pto.if_context( + is_first_k_tile, has_else=True + ) as branch: + tile.matmul(a_l0, b_l0, c_l0) + with branch.else_context(): + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) + else: + tile.matmul_acc(c_l0, a_l0, b_l0, c_l0) + + with pto.if_context(k_idx + c1 < k_dtile_num): + sv_a_next = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, k_offset + c_kd], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a_next, a_l1) + + sv_c = pto.slice_view( + tile_view_c, + source=tv_c, + offsets=[m_offset, n_offset], + sizes=[const(M_TILE), c_nt], + ) + pto.store(c_l0, sv_c) + + return matmul_kernel_step1_baseline + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + _ = parser.parse_args() + print(build()) diff --git a/examples/aot/matmul_optimization_guide/step1_baseline_numpy_sim.py b/examples/aot/matmul_optimization_guide/step1_baseline_numpy_sim.py new file mode 100644 index 00000000..ec34901b --- /dev/null +++ b/examples/aot/matmul_optimization_guide/step1_baseline_numpy_sim.py @@ -0,0 +1,116 @@ +import numpy as np + +M_TILE = 128 
+K_QTILE = 64 +K_TILE = 256 +K_DTILE = 512 +N_FULL = 256 + + +def _print_tile_memory(name, arr): + kib = arr.nbytes / 1024 + print( + f"[tile-mem] {name}: shape={arr.shape}, dtype={arr.dtype}, bytes={arr.nbytes} ({kib:.1f} KiB)" + ) + + +def step1_numpy_sim(a, b): + """ + a: [m, k] float16/float32 + b: [n, k] float16/float32 + returns c: [m, n], equivalent to a @ b.T + """ + m_total, k_total = a.shape + n_total, k_total_b = b.shape + assert k_total == k_total_b + assert m_total % M_TILE == 0, "Step1 kernel uses full M tiles in this demo." + assert k_total % K_DTILE == 0, "Step1 kernel uses full K_DTILE tiles." + assert n_total % N_FULL == 0, "Tutorial simulation assumes full N tiles." + + # Corresponds to: n_loop, m_loop, core_loop, k_dtile_num + n_loop = (n_total + N_FULL - 1) // N_FULL + m_loop = m_total // M_TILE + core_loop = n_loop * m_loop + k_dtile_num = k_total // K_DTILE + + c = np.zeros((m_total, n_total), dtype=np.float32) + + # Explicit tile-buffer allocation (mirrors pto.alloc_tile in step1_baseline.py). + # Keep shapes fixed to tutorial constants for easy hardware-memory cross-checks. 
+ # a_l1: M_TILE * K_DTILE * sizeof(float16) = 128 * 512 * 2 = 131072 B = 128 KiB + a_l1 = np.empty((M_TILE, K_DTILE), dtype=np.float16) + # b_l1: K_TILE * N_FULL * sizeof(float16) = 256 * 256 * 2 = 131072 B = 128 KiB + b_l1 = np.empty((K_TILE, N_FULL), dtype=np.float16) + # a_l0: M_TILE * K_QTILE * sizeof(float16) = 128 * 64 * 2 = 16384 B = 16 KiB + a_l0 = np.empty((M_TILE, K_QTILE), dtype=np.float16) + # b_l0: K_QTILE * N_FULL * sizeof(float16) = 64 * 256 * 2 = 32768 B = 32 KiB + b_l0 = np.empty((K_QTILE, N_FULL), dtype=np.float16) + # c_tile: M_TILE * N_FULL * sizeof(float32) = 128 * 256 * 4 = 131072 B = 128 KiB + c_tile = np.empty((M_TILE, N_FULL), dtype=np.float32) + + _print_tile_memory("a_l1", a_l1) + _print_tile_memory("b_l1", b_l1) + _print_tile_memory("a_l0", a_l0) + _print_tile_memory("b_l0", b_l0) + _print_tile_memory("c_tile", c_tile) + + # Corresponds to: for li in pto.range(...) + for li in range(core_loop): + # Corresponds to: m_idx = li // n_loop; n_idx = li % n_loop + m_idx = li // n_loop + n_idx = li % n_loop + m_offset = m_idx * M_TILE + n_offset = n_idx * N_FULL + + # Corresponds to tile accumulator c_l0 (reused buffer, reset per output tile). + c_tile.fill(0.0) + + for k_idx in range(k_dtile_num): + k_offset = k_idx * K_DTILE + + # Prefetch A tile for current K chunk (equivalent to pto.load into a_l1). + a_l1[:, :] = a[m_offset : m_offset + M_TILE, k_offset : k_offset + K_DTILE] + + # Corresponds to: for phase in range(8) + for phase in range(8): + # Corresponds to loading one B half tile every 4 phases + if phase % 4 == 0: + b_half = phase // 4 + h_off = b_half * K_TILE + # b_l1 layout is [K_TILE, N_FULL], matching tile_buf_b_l1. 
+ b_l1[:, :] = b[ + n_offset : n_offset + N_FULL, + k_offset + h_off : k_offset + h_off + K_TILE, + ].T + + # Corresponds to extract A/B quarter tiles + a_col = phase * K_QTILE + b_row = (phase % 4) * K_QTILE + a_l0[:, :] = a_l1[:, a_col : a_col + K_QTILE] + b_l0[:, :] = b_l1[b_row : b_row + K_QTILE, :] + + # Emulated tile matmul instruction: + # lhs a_l0: [M_TILE, K_QTILE] = [128, 64], fp16 source + # rhs b_l0: [K_QTILE, N_FULL] = [64, 256], fp16 source + # out c_tile: [M_TILE, N_FULL] = [128, 256], fp32 accumulate + # Keep tile storage in fp16; cast only right at matmul for fp16->fp32 accumulate. + c_tile += a_l0.astype(np.float32) @ b_l0.astype(np.float32) + + c[m_offset : m_offset + M_TILE, n_offset : n_offset + N_FULL] = c_tile + + return c + + +def test_step1_numpy_sim(): + np.random.seed(0) + for m, n, k in [(256, 512, 512), (384, 768, 1024)]: + a = np.random.randn(m, k).astype(np.float16) + b = np.random.randn(n, k).astype(np.float16) + c_ref = a.astype(np.float32) @ b.astype(np.float32).T + c_sim = step1_numpy_sim(a, b) + np.testing.assert_allclose(c_sim, c_ref, rtol=1e-4, atol=1e-3) + print("step1_numpy_sim unit test passed") + + +if __name__ == "__main__": + test_step1_numpy_sim() diff --git a/examples/aot/matmul_optimization_guide/step2_doublebuffer.py b/examples/aot/matmul_optimization_guide/step2_doublebuffer.py new file mode 100644 index 00000000..cb2344e6 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/step2_doublebuffer.py @@ -0,0 +1,153 @@ +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +from common_utils import ( + K_DTILE, + K_QTILE, + K_TILE, + M_TILE, + N_FULL, + build_meta_data, + const, +) + + +def build(): + meta_data = build_meta_data() + + @to_ir_module(meta_data=meta_data) + def matmul_kernel_ABt_autosync( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + m_i32: "i32", + n_i32: "i32", + k_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = 
const(1) + c2 = const(2) + c128 = const(M_TILE) + c256 = const(N_FULL) + c512 = const(K_DTILE) + + m_total = s.index_cast(m_i32) + n_total = s.index_cast(n_i32) + k_total = s.index_cast(k_i32) + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + + n_loop = (n_total + c256 - c1) // c256 + m_loop = m_total // c128 + core_loop = n_loop * m_loop + k_dtile_num = k_total // c512 + + tv_a = pto.as_tensor( + tv_2d, ptr=a_ptr, shape=[m_total, k_total], strides=[k_total, c1] + ) + tv_b = pto.as_tensor( + tv_2d, + ptr=b_ptr, + shape=[k_total, n_total], + strides=[c1, k_total], + layout="DN", + ) + tv_c = pto.as_tensor( + tv_2d, ptr=c_ptr, shape=[m_total, n_total], strides=[n_total, c1] + ) + + a_l1 = [pto.alloc_tile(tile_buf_a_l1), pto.alloc_tile(tile_buf_a_l1)] + b_l1 = [pto.alloc_tile(tile_buf_b_l1), pto.alloc_tile(tile_buf_b_l1)] + a_l0 = [pto.alloc_tile(tile_buf_a_l0), pto.alloc_tile(tile_buf_a_l0)] + b_l0 = [pto.alloc_tile(tile_buf_b_l0), pto.alloc_tile(tile_buf_b_l0)] + c_l0 = pto.alloc_tile(tile_buf_c) + + for li in pto.range(bid, core_loop, num_blocks): + m_idx = li // n_loop + n_idx = li % n_loop + + m_offset = m_idx * c128 + n_offset = n_idx * c256 + c_kt = const(K_TILE) + c_kd = const(K_DTILE) + c_nt = const(N_FULL) + + sv_a0 = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, c0], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a0, a_l1[0]) + + for k_idx in pto.range(c0, k_dtile_num, c1): + k_offset = k_idx * c_kd + + def run_loop_k(a_curr, a_next): + is_first_k_tile = k_idx == c0 + + for h in range(2): + h_off = const(h * K_TILE) + sv_b = pto.slice_view( + tile_view_b, + source=tv_b, + offsets=[k_offset + h_off, n_offset], + sizes=[c_kt, c_nt], + ) + pto.load(sv_b, b_l1[h]) + + for quarter in range(4): + phase = h * 4 + quarter + ping = phase & 1 + a_col = const(phase * K_QTILE) + b_row = const(quarter * K_QTILE) + + tile.extract(a_curr, c0, a_col, a_l0[ping]) + tile.extract(b_l1[h], b_row, c0, b_l0[ping]) 
+ + if phase == 0: + pto.cond( + is_first_k_tile, + lambda: tile.matmul( + a_l0[ping], b_l0[ping], c_l0 + ), + lambda: tile.matmul_acc( + c_l0, a_l0[ping], b_l0[ping], c_l0 + ), + ) + else: + tile.matmul_acc(c_l0, a_l0[ping], b_l0[ping], c_l0) + + with pto.if_context(k_idx + c1 < k_dtile_num): + sv_a_next = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, k_offset + c_kd], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a_next, a_next) + + is_curr0 = (k_idx % c2) == c0 + with pto.if_context(is_curr0, has_else=True) as branch: + run_loop_k(a_l1[0], a_l1[1]) + with branch.else_context(): + run_loop_k(a_l1[1], a_l1[0]) + + sv_c = pto.slice_view( + tile_view_c, + source=tv_c, + offsets=[m_offset, n_offset], + sizes=[const(M_TILE), c_nt], + ) + pto.store(c_l0, sv_c) + + return matmul_kernel_ABt_autosync + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + _ = parser.parse_args() + print(build()) diff --git a/examples/aot/matmul_optimization_guide/step3_swizzle.py b/examples/aot/matmul_optimization_guide/step3_swizzle.py new file mode 100644 index 00000000..93c0cf35 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/step3_swizzle.py @@ -0,0 +1,157 @@ +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +from common_utils import ( + K_DTILE, + K_QTILE, + K_TILE, + M_TILE, + N_FULL, + SWIZZLE_COUNT, + build_meta_data, + const, + swizzle_nz, +) + + +def build(): + meta_data = build_meta_data() + + @to_ir_module(meta_data=meta_data) + def matmul_kernel_ABt_autosync( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + m_i32: "i32", + n_i32: "i32", + k_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + c2 = const(2) + c128 = const(M_TILE) + c256 = const(N_FULL) + c512 = const(K_DTILE) + + m_total = s.index_cast(m_i32) + n_total = s.index_cast(n_i32) + k_total = s.index_cast(k_i32) + num_blocks = s.index_cast(pto.get_block_num()) + bid = 
s.index_cast(pto.get_block_idx()) + + n_loop = (n_total + c256 - c1) // c256 + m_loop = m_total // c128 + core_loop = n_loop * m_loop + k_dtile_num = k_total // c512 + c_swizzle = const(SWIZZLE_COUNT) + c_swizzle_m1 = c_swizzle - c1 + + tv_a = pto.as_tensor( + tv_2d, ptr=a_ptr, shape=[m_total, k_total], strides=[k_total, c1] + ) + tv_b = pto.as_tensor( + tv_2d, + ptr=b_ptr, + shape=[k_total, n_total], + strides=[c1, k_total], + layout="DN", + ) + tv_c = pto.as_tensor( + tv_2d, ptr=c_ptr, shape=[m_total, n_total], strides=[n_total, c1] + ) + + a_l1 = [pto.alloc_tile(tile_buf_a_l1), pto.alloc_tile(tile_buf_a_l1)] + b_l1 = [pto.alloc_tile(tile_buf_b_l1), pto.alloc_tile(tile_buf_b_l1)] + a_l0 = [pto.alloc_tile(tile_buf_a_l0), pto.alloc_tile(tile_buf_a_l0)] + b_l0 = [pto.alloc_tile(tile_buf_b_l0), pto.alloc_tile(tile_buf_b_l0)] + c_l0 = pto.alloc_tile(tile_buf_c) + + for li in pto.range(bid, core_loop, num_blocks): + m_idx, n_idx = swizzle_nz( + li, m_loop, n_loop, c_swizzle, c_swizzle_m1, c1, c2 + ) + m_offset = m_idx * c128 + n_offset = n_idx * c256 + c_kt = const(K_TILE) + c_kd = const(K_DTILE) + c_nt = const(N_FULL) + + sv_a0 = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, c0], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a0, a_l1[0]) + + for k_idx in pto.range(c0, k_dtile_num, c1): + k_offset = k_idx * c_kd + + def run_loop_k(a_curr, a_next): + is_first_k_tile = k_idx == c0 + + for h in range(2): + h_off = const(h * K_TILE) + sv_b = pto.slice_view( + tile_view_b, + source=tv_b, + offsets=[k_offset + h_off, n_offset], + sizes=[c_kt, c_nt], + ) + pto.load(sv_b, b_l1[h]) + + for quarter in range(4): + phase = h * 4 + quarter + ping = phase & 1 + a_col = const(phase * K_QTILE) + b_row = const(quarter * K_QTILE) + + tile.extract(a_curr, c0, a_col, a_l0[ping]) + tile.extract(b_l1[h], b_row, c0, b_l0[ping]) + + if phase == 0: + pto.cond( + is_first_k_tile, + lambda: tile.matmul( + a_l0[ping], b_l0[ping], c_l0 + ), + lambda: tile.matmul_acc( + 
c_l0, a_l0[ping], b_l0[ping], c_l0 + ), + ) + else: + tile.matmul_acc(c_l0, a_l0[ping], b_l0[ping], c_l0) + + with pto.if_context(k_idx + c1 < k_dtile_num): + sv_a_next = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, k_offset + c_kd], + sizes=[const(M_TILE), c_kd], + ) + pto.load(sv_a_next, a_next) + + is_curr0 = (k_idx % c2) == c0 + with pto.if_context(is_curr0, has_else=True) as branch: + run_loop_k(a_l1[0], a_l1[1]) + with branch.else_context(): + run_loop_k(a_l1[1], a_l1[0]) + + sv_c = pto.slice_view( + tile_view_c, + source=tv_c, + offsets=[m_offset, n_offset], + sizes=[const(M_TILE), c_nt], + ) + pto.store(c_l0, sv_c) + + return matmul_kernel_ABt_autosync + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + _ = parser.parse_args() + print(build()) diff --git a/examples/aot/matmul_optimization_guide/step3_swizzle_numpy_sim.py b/examples/aot/matmul_optimization_guide/step3_swizzle_numpy_sim.py new file mode 100644 index 00000000..42b78887 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/step3_swizzle_numpy_sim.py @@ -0,0 +1,53 @@ +import numpy as np + + +def swizzle_nz_py(li, m_loop, n_loop, c_swizzle, c1=1, c2=2): + c_swizzle_m1 = c_swizzle - c1 + tile_block_loop = (n_loop + c_swizzle_m1) // c_swizzle + tile_block_span = c_swizzle * m_loop + tile_block_idx = li // tile_block_span + in_tile_block_idx = li % tile_block_span + is_last_block = tile_block_idx == (tile_block_loop - c1) + n_col_tail = n_loop - c_swizzle * tile_block_idx + n_col = n_col_tail if is_last_block else c_swizzle + m_idx = in_tile_block_idx // n_col + n_idx = tile_block_idx * c_swizzle + (in_tile_block_idx % n_col) + odd_block = (tile_block_idx % c2) == c1 + if odd_block: + m_idx = m_loop - m_idx - c1 + return m_idx, n_idx + + +def show_mapping(m_loop, n_loop, c_swizzle, preview=24): + core_loop = m_loop * n_loop + rows = [] + linear_order_grid = np.full((m_loop, n_loop), -1, dtype=np.int32) + swizzle_order_grid = np.full((m_loop, n_loop), -1, 
dtype=np.int32) + for li in range(min(core_loop, preview)): + m_linear = li // n_loop + n_linear = li % n_loop + m_swz, n_swz = swizzle_nz_py(li, m_loop, n_loop, c_swizzle) + linear_order_grid[m_linear, n_linear] = li + swizzle_order_grid[m_swz, n_swz] = li + rows.append((li, m_linear, n_linear, m_swz, n_swz)) + + arr = np.array(rows, dtype=np.int32) + print( + f"\n=== swizzle={c_swizzle}, m_loop={m_loop}, n_loop={n_loop}, core_loop={core_loop} ===" + ) + print("li | linear(m,n) -> swizzle(m,n)") + for li, ml, nl, ms, ns in arr: + print(f"{li:2d} | ({ml:2d},{nl:2d}) -> ({ms:2d},{ns:2d})") + + print("\nLinear traversal order grid (value = li):") + print(linear_order_grid) + print("\nSwizzled traversal order grid (value = li):") + print(swizzle_order_grid) + + +if __name__ == "__main__": + # Use a non-multiple n_loop to demonstrate tail handling. + m_loop = 4 + n_loop = 7 + for c_swizzle in [2, 3, 5]: + show_mapping(m_loop=m_loop, n_loop=n_loop, c_swizzle=c_swizzle, preview=28) diff --git a/examples/aot/matmul_optimization_guide/step4_manual_pipelining.py b/examples/aot/matmul_optimization_guide/step4_manual_pipelining.py new file mode 100644 index 00000000..a92d7ed5 --- /dev/null +++ b/examples/aot/matmul_optimization_guide/step4_manual_pipelining.py @@ -0,0 +1,202 @@ +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +from common_utils import ( + K_DTILE, + K_QTILE, + K_TILE, + M_TILE, + N_FULL, + SWIZZLE_COUNT, + build_meta_data, + const, + swizzle_nz, +) + + +def build(): + meta_data = build_meta_data() + + @to_ir_module(meta_data=meta_data) + def matmul_kernel_ABt( + a_ptr: "ptr_type", + b_ptr: "ptr_type", + c_ptr: "ptr_type", + m_i32: "i32", + n_i32: "i32", + k_i32: "i32", + ) -> None: + with pto.cube_section(): + c0 = const(0) + c1 = const(1) + c2 = const(2) + c128 = const(M_TILE) + c256 = const(N_FULL) + c512 = const(K_DTILE) + + m_total = s.index_cast(m_i32) + n_total = s.index_cast(n_i32) + k_total = 
s.index_cast(k_i32) + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + + n_loop = (n_total + c256 - c1) // c256 + m_loop = m_total // c128 + core_loop = n_loop * m_loop + k_dtile_num = k_total // c512 + c_swizzle = const(SWIZZLE_COUNT) + c_swizzle_m1 = c_swizzle - c1 + + tv_a = pto.as_tensor( + tv_2d, ptr=a_ptr, shape=[m_total, k_total], strides=[k_total, c1] + ) + tv_b = pto.as_tensor( + tv_2d, + ptr=b_ptr, + shape=[k_total, n_total], + strides=[c1, k_total], + layout="DN", + ) + tv_c = pto.as_tensor( + tv_2d, ptr=c_ptr, shape=[m_total, n_total], strides=[n_total, c1] + ) + + a_l1 = [pto.alloc_tile(tile_buf_a_l1), pto.alloc_tile(tile_buf_a_l1)] + b_l1 = [pto.alloc_tile(tile_buf_b_l1), pto.alloc_tile(tile_buf_b_l1)] + a_l0 = [pto.alloc_tile(tile_buf_a_l0), pto.alloc_tile(tile_buf_a_l0)] + b_l0 = [pto.alloc_tile(tile_buf_b_l0), pto.alloc_tile(tile_buf_b_l0)] + c_l0 = pto.alloc_tile(tile_buf_c) + + pto.record_event("MATMUL", "MOV_M2L", event_id=[0, 1]) + pto.record_event("MOV_M2L", "LOAD", event_id=[0, 1, 2, 3]) + + for li in pto.range(bid, core_loop, num_blocks): + m_idx, n_idx = swizzle_nz( + li, m_loop, n_loop, c_swizzle, c_swizzle_m1, c1, c2 + ) + m_offset = m_idx * c128 + n_offset = n_idx * c256 + c_kt = const(K_TILE) + c_kd = const(K_DTILE) + c_nt = const(N_FULL) + + not_first_tile = li != bid + with pto.if_context(not_first_tile): + pto.wait_event("STORE_ACC", "MATMUL", event_id=0) + + sv_a0 = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, c0], + sizes=[const(M_TILE), c_kd], + ) + pto.wait_event("MOV_M2L", "LOAD", event_id=0) + pto.load(sv_a0, a_l1[0]) + pto.record_event("LOAD", "MOV_M2L", event_id=0) + + for k_idx in pto.range(c0, k_dtile_num, c1): + k_offset = k_idx * c_kd + + def run_loop_k(curr_id, next_id, a_curr, a_next): + is_first_k_tile = k_idx == c0 + + for h in range(2): + b_evt = 2 + h + h_off = const(h * K_TILE) + sv_b = pto.slice_view( + tile_view_b, + source=tv_b, + offsets=[k_offset 
+ h_off, n_offset], + sizes=[c_kt, c_nt], + ) + + pto.wait_event("MOV_M2L", "LOAD", event_id=b_evt) + pto.load(sv_b, b_l1[h]) + pto.record_event("LOAD", "MOV_M2L", event_id=b_evt) + + for quarter in range(4): + phase = h * 4 + quarter + ping = phase & 1 + a_col = const(phase * K_QTILE) + b_row = const(quarter * K_QTILE) + + pto.wait_event("MATMUL", "MOV_M2L", event_id=ping) + if phase == 0: + pto.wait_event("LOAD", "MOV_M2L", event_id=curr_id) + + tile.extract(a_curr, c0, a_col, a_l0[ping]) + if phase == 7: + pto.record_event( + "MOV_M2L", "LOAD", event_id=curr_id + ) + + if quarter == 0: + pto.wait_event("LOAD", "MOV_M2L", event_id=b_evt) + + tile.extract(b_l1[h], b_row, c0, b_l0[ping]) + pto.record_event("MOV_M2L", "MATMUL", event_id=0) + + if quarter == 3: + pto.record_event("MOV_M2L", "LOAD", event_id=b_evt) + + pto.wait_event("MOV_M2L", "MATMUL", event_id=0) + if phase == 0: + pto.cond( + is_first_k_tile, + lambda: tile.matmul( + a_l0[ping], b_l0[ping], c_l0 + ), + lambda: tile.matmul_acc( + c_l0, a_l0[ping], b_l0[ping], c_l0 + ), + ) + else: + tile.matmul_acc(c_l0, a_l0[ping], b_l0[ping], c_l0) + + pto.record_event("MATMUL", "MOV_M2L", event_id=ping) + + with pto.if_context(k_idx + c1 < k_dtile_num): + sv_a_next = pto.slice_view( + tile_view_a, + source=tv_a, + offsets=[m_offset, k_offset + c_kd], + sizes=[const(M_TILE), c_kd], + ) + pto.wait_event("MOV_M2L", "LOAD", event_id=next_id) + pto.load(sv_a_next, a_next) + pto.record_event("LOAD", "MOV_M2L", event_id=next_id) + + is_curr0 = (k_idx % c2) == c0 + with pto.if_context(is_curr0, has_else=True) as branch: + run_loop_k(0, 1, a_l1[0], a_l1[1]) + with branch.else_context(): + run_loop_k(1, 0, a_l1[1], a_l1[0]) + + sv_c = pto.slice_view( + tile_view_c, + source=tv_c, + offsets=[m_offset, n_offset], + sizes=[const(M_TILE), c_nt], + ) + pto.record_wait_pair("MATMUL", "STORE_ACC", event_id=0) + pto.store(c_l0, sv_c) + + with pto.if_context(li + num_blocks < core_loop): + pto.record_event("STORE_ACC", "MATMUL", 
event_id=0) + + pto.wait_event("MOV_M2L", "LOAD", event_id=3) + pto.wait_event("MOV_M2L", "LOAD", event_id=2) + pto.wait_event("MOV_M2L", "LOAD", event_id=1) + pto.wait_event("MOV_M2L", "LOAD", event_id=0) + pto.wait_event("MATMUL", "MOV_M2L", event_id=0) + pto.wait_event("MATMUL", "MOV_M2L", event_id=1) + + return matmul_kernel_ABt + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + _ = parser.parse_args() + print(build()) diff --git a/examples/aot/print_tile/.gitignore b/examples/aot/print_tile/.gitignore new file mode 100644 index 00000000..1c66f912 --- /dev/null +++ b/examples/aot/print_tile/.gitignore @@ -0,0 +1 @@ +print_gen.cpp diff --git a/examples/aot/print_tile/README.md b/examples/aot/print_tile/README.md new file mode 100644 index 00000000..3335fd00 --- /dev/null +++ b/examples/aot/print_tile/README.md @@ -0,0 +1,6 @@ +Usage: + +```bash +bash compile.sh +python ./run_print.py +``` diff --git a/examples/aot/print_tile/caller.cpp b/examples/aot/print_tile/caller.cpp new file mode 100644 index 00000000..6ae57e88 --- /dev/null +++ b/examples/aot/print_tile/caller.cpp @@ -0,0 +1,9 @@ +#include "print_gen.cpp" + +extern "C" void call_kernel( + void *stream, uint8_t *x, uint8_t *y, uint8_t *z, int32_t vrow, int32_t vcol) +{ + vec_add_kernel_2d_dynamic<<<2, nullptr, stream>>>( + (float *)x, (float *)y, (float *)z, vrow, vcol + ); +} diff --git a/examples/aot/print_tile/compile.sh b/examples/aot/print_tile/compile.sh new file mode 100644 index 00000000..9a55a8c4 --- /dev/null +++ b/examples/aot/print_tile/compile.sh @@ -0,0 +1,39 @@ + +#!/usr/bin/env bash +set -e + +ARCH=$(uname -m) + +PTO_DIR="$ASCEND_HOME_PATH/include/pto" +PTO_BACKUP="$ASCEND_HOME_PATH/include/pto_hidden" +PTO_LIB_PATH="/sources/pto-isa" +[ -d "$PTO_LIB_PATH" ] || exit 0 + + +rm -f print_lib.so print_gen.cpp +python ./print_builder.py | ptoas --enable-insert-sync > print_gen.cpp + +restore() { + if [ -d "$PTO_BACKUP" ]; then + mv "$PTO_BACKUP" "$PTO_DIR" + fi +} + +# For 
now we have to hide the CANN built-in headers, and use the cloned pto-isa's +# c.f. https://gitcode.com/cann/pto-isa/issues/149 +mv "$PTO_DIR" "$PTO_BACKUP" + +# Make restore run on EXIT +trap restore EXIT + +bisheng \ + -I${ASCEND_TOOLKIT_HOME}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -xcce -Xhost-start -Xhost-end \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -D_DEBUG --cce-enable-print \ + -I${ASCEND_HOME_PATH}/${ARCH}-linux/pkg_inc/runtime/runtime \ + -I${PTO_LIB_PATH}/include \ + -std=gnu++17 \ + ./caller.cpp \ + -o ./print_lib.so diff --git a/examples/aot/print_tile/print_builder.py b/examples/aot/print_tile/print_builder.py new file mode 100644 index 00000000..9321fb33 --- /dev/null +++ b/examples/aot/print_tile/print_builder.py @@ -0,0 +1,87 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + + +def meta_data(): + # common, reusable type declarations + dtype = pto.float32 + index_dtype = pto.int32 + ptr_type = pto.PtrType(dtype) + tensor_type = pto.TensorType(rank=2, dtype=dtype) + subtensor_type = pto.SubTensorType( + shape=[32, 32], dtype=dtype + ) # TODO: omit shape https://github.com/zhangstevenunity/PTOAS/issues/31 + tile_cfg = pto.TileBufConfig() + # defaults to pto.TileBufConfig(blayout="RowMajor", slayout="NoneBox", s_fractal_size=512, pad="Null") + tile_type = pto.TileBufType( + shape=[32, 32], + valid_shape=[-1, -1], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) + return { + "ptr_type": ptr_type, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_type": subtensor_type, + "tile_type": tile_type, + } + + +@to_ir_module(meta_data=meta_data) +def vec_add_kernel_2d_dynamic( + arg0: "ptr_type", + arg1: "ptr_type", + arg2: "ptr_type", + arg_vrow_i32: "index_dtype", + arg_vcol_i32: "index_dtype", +) -> None: + c0 = const(0) + c1 = const(1) + c32 = const(32) + c1280 = const(1280) + + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum 
= pto.get_subblock_num() + cidmul = cid * sub_bnum + vid = cidmul + sub_bid + + v_row_idx = s.index_cast(arg_vrow_i32) + v_col_idx = s.index_cast(arg_vcol_i32) + + tv0 = pto.as_tensor(tensor_type, ptr=arg0, shape=[c1280, c32], strides=[c32, c1]) + tv1 = pto.as_tensor(tensor_type, ptr=arg1, shape=[c1280, c32], strides=[c32, c1]) + tv2 = pto.as_tensor(tensor_type, ptr=arg2, shape=[c1280, c32], strides=[c32, c1]) + + vid_idx = s.index_cast(vid) + offset_row = vid_idx * c32 # every core loads 32 rows of data + sv0 = pto.slice_view( + subtensor_type, source=tv0, offsets=[offset_row, c0], sizes=[c32, c32] + ) + sv1 = pto.slice_view( + subtensor_type, source=tv1, offsets=[offset_row, c0], sizes=[c32, c32] + ) + sv2 = pto.slice_view( + subtensor_type, source=tv2, offsets=[offset_row, c0], sizes=[c32, c32] + ) + + with pto.vector_section(): + tb0 = pto.alloc_tile(tile_type, valid_row=v_row_idx, valid_col=v_col_idx) + tb1 = pto.alloc_tile(tile_type, valid_row=v_row_idx, valid_col=v_col_idx) + tb2 = pto.alloc_tile(tile_type, valid_row=v_row_idx, valid_col=v_col_idx) + + pto.load(sv0, tb0) + pto.load(sv1, tb1) + pto.print("hello%d\n", c1) + tile.print(tb0) + tile.add(tb0, tb1, tb2) + pto.store(tb2, sv2) + + +if __name__ == "__main__": + module = vec_add_kernel_2d_dynamic + print(module) diff --git a/examples/aot/print_tile/run_print.py b/examples/aot/print_tile/run_print.py new file mode 100644 index 00000000..1316bc19 --- /dev/null +++ b/examples/aot/print_tile/run_print.py @@ -0,0 +1,51 @@ +import ctypes +import torch +import torch_npu +from ptodsl.test_util import get_test_device + + +def torch_to_ctypes(tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +def lib_to_func(lib): + def add_func(x, y, z, stream_ptr=None): + + vrow, vcol = 32, 32 # local tile shape hard-coded as the kernel + + if stream_ptr is None: + stream_ptr = torch.npu.current_stream()._as_parameter_ + + lib.call_kernel( + stream_ptr, + torch_to_ctypes(x), + torch_to_ctypes(y), + torch_to_ctypes(z), + 
vrow, + vcol, + ) + + return add_func + + +def test_add(): + device = get_test_device() + torch.npu.set_device(device) + + lib_path = "./print_lib.so" + lib = ctypes.CDLL(lib_path) + add_func = lib_to_func(lib) + + shape = [1280, 32] # tensor shape hard-coded as the kernel + torch.manual_seed(0) + dtype = torch.float32 + x = torch.arange(shape[0] * shape[1], device=device, dtype=dtype).reshape(shape) + y = torch.arange(shape[0] * shape[1], device=device, dtype=dtype).reshape(shape) + z = torch.empty(shape, device=device, dtype=dtype) + + add_func(x, y, z) + torch.npu.synchronize() + + +if __name__ == "__main__": + test_add() diff --git a/examples/aot/add_static_multicore/.gitignore b/examples/aot/simple_static/add_static_multicore/.gitignore similarity index 50% rename from examples/aot/add_static_multicore/.gitignore rename to examples/aot/simple_static/add_static_multicore/.gitignore index 79b9aff4..1a4c7666 100644 --- a/examples/aot/add_static_multicore/.gitignore +++ b/examples/aot/simple_static/add_static_multicore/.gitignore @@ -1,2 +1,2 @@ add.cpp -add.pto \ No newline at end of file +add.pto diff --git a/examples/aot/add_static_multicore/README.md b/examples/aot/simple_static/add_static_multicore/README.md similarity index 100% rename from examples/aot/add_static_multicore/README.md rename to examples/aot/simple_static/add_static_multicore/README.md diff --git a/examples/aot/add_static_multicore/add_builder.py b/examples/aot/simple_static/add_static_multicore/add_builder.py similarity index 66% rename from examples/aot/add_static_multicore/add_builder.py rename to examples/aot/simple_static/add_static_multicore/add_builder.py index 0525f4c7..1c790077 100644 --- a/examples/aot/add_static_multicore/add_builder.py +++ b/examples/aot/simple_static/add_static_multicore/add_builder.py @@ -1,6 +1,7 @@ -from ptodsl import to_ir_module -import ptodsl.language as pto -const = pto.const +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + 
+const = s.const def meta_data(): @@ -9,11 +10,18 @@ def meta_data(): index_dtype = pto.int32 ptr_type = pto.PtrType(dtype) tensor_type = pto.TensorType(rank=2, dtype=dtype) - subtensor_type = pto.SubTensorType(shape=[32, 32], dtype=dtype) # TODO: omit shape https://github.com/zhangstevenunity/PTOAS/issues/31 + subtensor_type = pto.SubTensorType( + shape=[32, 32], dtype=dtype + ) # TODO: omit shape https://github.com/zhangstevenunity/PTOAS/issues/31 tile_cfg = pto.TileBufConfig() # defaults to pto.TileBufConfig(blayout="RowMajor", slayout="NoneBox", s_fractal_size=512, pad="Null") tile_type = pto.TileBufType( - shape=[32, 32], valid_shape=[-1, -1], dtype=dtype, memory_space="VEC", config=tile_cfg) + shape=[32, 32], + valid_shape=[-1, -1], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) return { "ptr_type": ptr_type, "index_dtype": index_dtype, @@ -29,8 +37,8 @@ def vec_add_kernel_2d_dynamic( arg1: "ptr_type", arg2: "ptr_type", arg_vrow_i32: "index_dtype", - arg_vcol_i32: "index_dtype" - ) -> None: + arg_vcol_i32: "index_dtype", +) -> None: c0 = const(0) c1 = const(1) c32 = const(32) @@ -42,18 +50,24 @@ def vec_add_kernel_2d_dynamic( cidmul = cid * sub_bnum vid = cidmul + sub_bid - v_row_idx = pto.index_cast(arg_vrow_i32) - v_col_idx = pto.index_cast(arg_vcol_i32) + v_row_idx = s.index_cast(arg_vrow_i32) + v_col_idx = s.index_cast(arg_vcol_i32) tv0 = pto.as_tensor(tensor_type, ptr=arg0, shape=[c1280, c32], strides=[c32, c1]) tv1 = pto.as_tensor(tensor_type, ptr=arg1, shape=[c1280, c32], strides=[c32, c1]) tv2 = pto.as_tensor(tensor_type, ptr=arg2, shape=[c1280, c32], strides=[c32, c1]) - vid_idx = pto.index_cast(vid) + vid_idx = s.index_cast(vid) offset_row = vid_idx * c32 # every core loads 32 rows of data - sv0 = pto.slice_view(subtensor_type, source=tv0, offsets=[offset_row, c0], sizes=[c32, c32]) - sv1 = pto.slice_view(subtensor_type, source=tv1, offsets=[offset_row, c0], sizes=[c32, c32]) - sv2 = pto.slice_view(subtensor_type, source=tv2, 
offsets=[offset_row, c0], sizes=[c32, c32]) + sv0 = pto.slice_view( + subtensor_type, source=tv0, offsets=[offset_row, c0], sizes=[c32, c32] + ) + sv1 = pto.slice_view( + subtensor_type, source=tv1, offsets=[offset_row, c0], sizes=[c32, c32] + ) + sv2 = pto.slice_view( + subtensor_type, source=tv2, offsets=[offset_row, c0], sizes=[c32, c32] + ) with pto.vector_section(): tb0 = pto.alloc_tile(tile_type, valid_row=v_row_idx, valid_col=v_col_idx) @@ -62,7 +76,7 @@ def vec_add_kernel_2d_dynamic( pto.load(sv0, tb0) pto.load(sv1, tb1) - pto.add(tb0, tb1, tb2) + tile.add(tb0, tb1, tb2) pto.store(tb2, sv2) # `default `return None` maps to `func.ReturnOp([])` diff --git a/examples/aot/add_static_multicore/caller.cpp b/examples/aot/simple_static/add_static_multicore/caller.cpp similarity index 100% rename from examples/aot/add_static_multicore/caller.cpp rename to examples/aot/simple_static/add_static_multicore/caller.cpp diff --git a/examples/aot/add_static_multicore/compile.sh b/examples/aot/simple_static/add_static_multicore/compile.sh similarity index 100% rename from examples/aot/add_static_multicore/compile.sh rename to examples/aot/simple_static/add_static_multicore/compile.sh diff --git a/examples/aot/add_static_multicore/run_add.py b/examples/aot/simple_static/add_static_multicore/run_add.py similarity index 91% rename from examples/aot/add_static_multicore/run_add.py rename to examples/aot/simple_static/add_static_multicore/run_add.py index 44c42615..53344ff8 100644 --- a/examples/aot/add_static_multicore/run_add.py +++ b/examples/aot/simple_static/add_static_multicore/run_add.py @@ -9,12 +9,7 @@ def torch_to_ctypes(tensor): def lib_to_func(lib): - def add_func( - x, - y, - z, - stream_ptr=None - ): + def add_func(x, y, z, stream_ptr=None): vrow, vcol = 32, 32 # local tile shape hard-coded as the kernel @@ -26,8 +21,10 @@ def add_func( torch_to_ctypes(x), torch_to_ctypes(y), torch_to_ctypes(z), - vrow, vcol + vrow, + vcol, ) + return add_func @@ -53,5 +50,6 @@ def 
test_add(): torch.testing.assert_close(z, z_ref) print("result equal!") + if __name__ == "__main__": test_add() diff --git a/examples/aot/matmul_static_singlecore/.gitignore b/examples/aot/simple_static/matmul_static_singlecore/.gitignore similarity index 100% rename from examples/aot/matmul_static_singlecore/.gitignore rename to examples/aot/simple_static/matmul_static_singlecore/.gitignore diff --git a/examples/aot/simple_static/matmul_static_singlecore/README.md b/examples/aot/simple_static/matmul_static_singlecore/README.md new file mode 100644 index 00000000..17fc4de9 --- /dev/null +++ b/examples/aot/simple_static/matmul_static_singlecore/README.md @@ -0,0 +1,4 @@ +```bash +bash ./compile.sh +python ./run_matmul.py +``` diff --git a/examples/aot/matmul_static_singlecore/caller.cpp b/examples/aot/simple_static/matmul_static_singlecore/caller.cpp similarity index 100% rename from examples/aot/matmul_static_singlecore/caller.cpp rename to examples/aot/simple_static/matmul_static_singlecore/caller.cpp diff --git a/examples/aot/matmul_static_singlecore/compile.sh b/examples/aot/simple_static/matmul_static_singlecore/compile.sh similarity index 100% rename from examples/aot/matmul_static_singlecore/compile.sh rename to examples/aot/simple_static/matmul_static_singlecore/compile.sh diff --git a/examples/aot/matmul_static_singlecore/matmul_builder.py b/examples/aot/simple_static/matmul_static_singlecore/matmul_builder.py similarity index 70% rename from examples/aot/matmul_static_singlecore/matmul_builder.py rename to examples/aot/simple_static/matmul_static_singlecore/matmul_builder.py index ab5e2ffa..6b11b015 100644 --- a/examples/aot/matmul_static_singlecore/matmul_builder.py +++ b/examples/aot/simple_static/matmul_static_singlecore/matmul_builder.py @@ -1,7 +1,7 @@ # adapted from https://github.com/zhangstevenunity/PTOAS/blob/a301aa43b388d9b2e1ba0db8773b3a719e8c445b/test/samples/MatMul/tmatmulk.py -from ptodsl import to_ir_module -import ptodsl.language as pto 
+from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s def build( @@ -28,11 +28,21 @@ def meta_data(): tile_view_out = pto.SubTensorType(shape=[M, N], dtype=dtype) tile_view_bias = pto.SubTensorType(shape=[1, N], dtype=dtype) - tile_buf_aMat = pto.TileBufType(shape=[M, BASEK], dtype=dtype, memory_space="MAT") - tile_buf_bMat = pto.TileBufType(shape=[BASEK, N], dtype=dtype, memory_space="MAT") - tile_buf_biasData = pto.TileBufType(shape=[1, N], dtype=dtype, memory_space="MAT") - tile_buf_aTile = pto.TileBufType(shape=[M, BASEK], dtype=dtype, memory_space="LEFT") - tile_buf_bTile = pto.TileBufType(shape=[BASEK, N], dtype=dtype, memory_space="RIGHT") + tile_buf_aMat = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="MAT" + ) + tile_buf_bMat = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="MAT" + ) + tile_buf_biasData = pto.TileBufType( + shape=[1, N], dtype=dtype, memory_space="MAT" + ) + tile_buf_aTile = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="LEFT" + ) + tile_buf_bTile = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="RIGHT" + ) tile_buf_cTile = pto.TileBufType(shape=[M, N], dtype=dtype, memory_space="ACC") tile_buf_biasTile = pto.TileBufType( shape=[1, N], dtype=dtype, memory_space="BIAS" @@ -55,7 +65,7 @@ def meta_data(): "tile_buf_biasTile": tile_buf_biasTile, } - const = pto.const + const = s.const @to_ir_module(meta_data=meta_data) def RunTMATMULSplitK( @@ -76,10 +86,18 @@ def RunTMATMULSplitK( cTileM = const(M) cTileN = const(N) - tvA = pto.as_tensor(tensor_type, ptr=a_ptr, shape=[cM, cK], strides=[cK, c1]) - tvB = pto.as_tensor(tensor_type, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1]) - tvOut = pto.as_tensor(tensor_type, ptr=out_ptr, shape=[cM, cN], strides=[cN, c1]) - tvBias = pto.as_tensor(tensor_type, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1]) + tvA = pto.as_tensor( + tensor_type, ptr=a_ptr, shape=[cM, cK], strides=[cK, c1] + ) + tvB = pto.as_tensor( + 
tensor_type, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1] + ) + tvOut = pto.as_tensor( + tensor_type, ptr=out_ptr, shape=[cM, cN], strides=[cN, c1] + ) + tvBias = pto.as_tensor( + tensor_type, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1] + ) aMatTile = pto.alloc_tile(tile_buf_aMat) bMatTile = pto.alloc_tile(tile_buf_bMat) @@ -89,7 +107,7 @@ def RunTMATMULSplitK( cTile = pto.alloc_tile(tile_buf_cTile) biasTile = pto.alloc_tile(tile_buf_biasTile) - for i in pto.for_range(c0, cIter, c1): + for i in pto.range(c0, cIter, c1): kOff = i * cBASEK svA = pto.slice_view( tile_view_a, @@ -115,24 +133,24 @@ def RunTMATMULSplitK( with pto.if_context(isBias): pto.load(svBias, biasDataTile) - pto.mov(aMatTile, aTile) - pto.mov(bMatTile, bTile) + tile.mov(aMatTile, aTile) + tile.mov(bMatTile, bTile) with pto.if_context(isBias): - pto.mov(biasDataTile, biasTile) + tile.mov(biasDataTile, biasTile) - is_i0 = pto.eq(i, c0) + is_i0 = s.eq(i, c0) def _first_iter(): pto.cond( isBias, - lambda: pto.matmul_bias(aTile, bTile, biasTile, cTile), - lambda: pto.matmul(aTile, bTile, cTile), + lambda: tile.matmul_bias(aTile, bTile, biasTile, cTile), + lambda: tile.matmul(aTile, bTile, cTile), ) pto.cond( is_i0, _first_iter, - lambda: pto.matmul_acc(cTile, aTile, bTile, cTile), + lambda: tile.matmul_acc(cTile, aTile, bTile, cTile), ) svOut = pto.slice_view( @@ -148,4 +166,4 @@ def _first_iter(): if __name__ == "__main__": - print(build()) \ No newline at end of file + print(build()) diff --git a/examples/aot/matmul_static_singlecore/run_matmul.py b/examples/aot/simple_static/matmul_static_singlecore/run_matmul.py similarity index 81% rename from examples/aot/matmul_static_singlecore/run_matmul.py rename to examples/aot/simple_static/matmul_static_singlecore/run_matmul.py index 46822a5c..1eddc067 100644 --- a/examples/aot/matmul_static_singlecore/run_matmul.py +++ b/examples/aot/simple_static/matmul_static_singlecore/run_matmul.py @@ -13,11 +13,7 @@ def load_lib(lib_path): default_block_dim = 1 # 
NOTE: kernel is single-core for now - def matmul_func( - c, a, b, - block_dim=default_block_dim, - stream_ptr=None - ): + def matmul_func(c, a, b, block_dim=default_block_dim, stream_ptr=None): if stream_ptr is None: stream_ptr = torch.npu.current_stream()._as_parameter_ lib.call_kernel( @@ -39,8 +35,8 @@ def test_matmul(): m, k, n = 32, 256, 32 torch.manual_seed(0) - a = torch.rand((m,k), device=device, dtype=dtype) - b = torch.rand((k,n), device=device, dtype=dtype) + a = torch.rand((m, k), device=device, dtype=dtype) + b = torch.rand((k, n), device=device, dtype=dtype) c = torch.zeros((m, n), device=device, dtype=dtype) matmul_func = load_lib("./matmul_kernel.so") @@ -49,7 +45,7 @@ def test_matmul(): c_ref = torch.matmul(a, b) diff = (c - c_ref).abs().max() - print('max diff: ', diff) + print("max diff: ", diff) if __name__ == "__main__": diff --git a/examples/aot/topk/.gitignore b/examples/aot/topk/.gitignore new file mode 100644 index 00000000..e0e1a224 --- /dev/null +++ b/examples/aot/topk/.gitignore @@ -0,0 +1,4 @@ +caller.cpp +topk_float32.pto +topk_float32.cpp +topk_float32_lib.so diff --git a/examples/aot/topk/README.md b/examples/aot/topk/README.md new file mode 100644 index 00000000..b99e49bd --- /dev/null +++ b/examples/aot/topk/README.md @@ -0,0 +1,77 @@ +# TopK (AOT, dynamic n_rows, float32) + +Finds the top-K largest elements per row of a 2-D `[N_ROWS × N_COLS]` float32 +matrix using a TSORT32 → TMRGSORT → TGATHER pipeline on the NPU vector engine. + +`N_ROWS` is a **runtime** argument — a single compiled `.so` handles any row +count without recompilation. `N_COLS`, `TOPK`, and `BLOCK_DIM` are +compile-time constants because they govern tile buffer sizes and the number of +merge-sort passes, which must be statically known by the hardware. 
+ +## Parameters + +| Symbol | Kind | Default | Meaning | +|------------------|:------------:|--------:|--------------------------------------| +| `N_ROWS` | **runtime** | any | rows in the input matrix | +| `N_COLS` | compile-time | 512 | input elements per row | +| `TOPK` | compile-time | 256 | top-k output count per row | +| `BLOCK_DIM` | compile-time | 24 | number of NPU compute blocks | +| `SORT_BLOCK_LEN` | compile-time | 32 | TSORT32 sorts in blocks of this many | + +Valid `N_COLS` values (with `SORT_BLOCK_LEN=32`): + +| `N_COLS` | `SORT_COLS` | Merge passes | +|---------:|------------:|:------------:| +| 128 | 256 | 1 | +| 512 | 1024 | 2 | +| 2048 | 4096 | 3 | + +## Pipeline (per row) + +``` +input row [1 x N_COLS] --> TSORT32 --> sort buffer [1 x 2*N_COLS] + (interleaved score/idx pairs) + TMRGSORT x passes --> fully sorted [1 x 2*N_COLS] + TMOV (gather window, valid=[1 x 2*TOPK]) + TGATHER P0101 --> tb_scores [1 x TOPK] float32 + TGATHER P1010 --> tb_indices [1 x TOPK] uint32 +``` + +The gather-window tile has `valid_shape=[1, 2*TOPK]`, which limits TGATHER +to exactly `TOPK` outputs even when `TOPK < N_COLS`. 
+ +## Usage + +Compile all configs and validate all 11 test cases: + +```bash +python ./run_topk.py +``` + +To compile a single config manually or skip recompilation: + +```text +# compile one config: N_COLS TOPK BLOCK_DIM +bash ./compile.sh 512 256 24 + +# skip recompilation if .so files already exist +python ./run_topk.py --no-compile +``` + +## Files + +| File | Purpose | +|-------------------|------------------------------------------------------------| +| `topk_builder.py` | PTO-DSL builder – emits MLIR for a given `(N_COLS, TOPK)` | +| `caller.py` | Generates `caller.cpp`; `n_rows` is an `int32_t` passed at call time | +| `compile.sh` | End-to-end build: PTO → MLIR → C++ → `.so` | +| `run_topk.py` | Validates 11 configs against `torch.topk` | + +Generated build artifacts (gitignored; intermediates are built in a temp dir that `compile.sh` removes on exit, only the `.so` is kept): + +| Artifact | Created by | +|----------------------------------|--------------| +| `caller.cpp` | `compile.sh` | +| `topk_c<N_COLS>_k<TOPK>.pto` | `compile.sh` | +| `topk_c<N_COLS>_k<TOPK>.cpp` | `compile.sh` | +| `topk_c<N_COLS>_k<TOPK>_lib.so` | `compile.sh` | diff --git a/examples/aot/topk/caller.py b/examples/aot/topk/caller.py new file mode 100644 index 00000000..7d0f4f12 --- /dev/null +++ b/examples/aot/topk/caller.py @@ -0,0 +1,48 @@ +"""Generate caller.cpp for a given TopK kernel function name. + +The generated file wraps the NPU kernel launch in an ``extern "C"`` function +that can be called from Python via ctypes. + +n_rows is passed at call time as an ``int32_t`` so the same shared library +handles any row count without recompilation. + +Usage +----- + python caller.py topk_c512_k256 + python caller.py topk_c128_k64 --block-dim 24 > caller.cpp +""" + +import argparse + +_DEFAULT_BLOCK_DIM = 24 + + +def generate(fn: str, block_dim: int = _DEFAULT_BLOCK_DIM) -> str: + return f"""\ +// Auto-generated by caller.py – do not edit by hand. 
+#include "{fn}.cpp" + +extern "C" void call_{fn}( + void *stream, + uint8_t *src, + uint8_t *inIdx, + uint8_t *out_scores, + uint8_t *out_indices, + int32_t n_rows) +{{ + {fn}<<<{block_dim}, nullptr, stream>>>( + reinterpret_cast(src), + reinterpret_cast(inIdx), + reinterpret_cast(out_scores), + reinterpret_cast(out_indices), + n_rows); +}} +""" + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("fn_name", help="e.g. topk_c512_k256") + parser.add_argument("--block-dim", type=int, default=_DEFAULT_BLOCK_DIM) + args = parser.parse_args() + print(generate(args.fn_name, args.block_dim), end="") diff --git a/examples/aot/topk/compile.sh b/examples/aot/topk/compile.sh new file mode 100755 index 00000000..d1ae0db4 --- /dev/null +++ b/examples/aot/topk/compile.sh @@ -0,0 +1,43 @@ +#!/bin/bash +# Compile one TopK kernel config into a shared library. +# +# Usage: bash compile.sh [N_COLS] [TOPK] [BLOCK_DIM] +# Defaults: 512 256 24 +# +# N_ROWS is a runtime argument – the same library handles any row count. 
+set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +N_COLS=${1:-512} +TOPK=${2:-256} +BLOCK_DIM=${3:-24} + +FN="topk_c${N_COLS}_k${TOPK}" + +TMP=$(mktemp -d) +trap "rm -rf $TMP" EXIT + +python "$SCRIPT_DIR/topk_builder.py" \ + --n-cols "$N_COLS" --topk "$TOPK" --block-dim "$BLOCK_DIM" \ + > "$TMP/${FN}.pto" +ptoas --enable-insert-sync "$TMP/${FN}.pto" -o "$TMP/${FN}.cpp" + +python "$SCRIPT_DIR/caller.py" "$FN" --block-dim "$BLOCK_DIM" > "$TMP/caller.cpp" + +bisheng \ + -I${ASCEND_TOOLKIT_HOME}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + "$TMP/caller.cpp" \ + -o "$SCRIPT_DIR/${FN}_lib.so" + +echo "Built ${FN}_lib.so successfully." diff --git a/examples/aot/topk/run_topk.py b/examples/aot/topk/run_topk.py new file mode 100644 index 00000000..edf7e83c --- /dev/null +++ b/examples/aot/topk/run_topk.py @@ -0,0 +1,161 @@ +""" +Run and validate the TopK AOT kernel for multiple configurations. + +The kernel is **dynamic** in n_rows: a single compiled .so handles any row +count. Configs that share the same (n_cols, topk) pair reuse the same library. 
+ +Usage: + python ./run_topk.py # compile + run all configs + python ./run_topk.py --no-compile # skip recompilation (libs already built) + +Valid N_COLS values (SORT_BLOCK_LEN=32) +--------------------------------------- + SORT_COLS = N_COLS*2 must be a power-of-4 multiple of HW_BLOCK_LEN=64: + N_COLS = 128 → 1 merge pass + N_COLS = 512 → 2 merge passes + N_COLS = 2048 → 3 merge passes +""" + +import argparse +import ctypes +import os +import subprocess + +import torch +import torch_npu + +from ptodsl.test_util import get_test_device +from topk_builder import fn_name + +_DIR = os.path.dirname(os.path.abspath(__file__)) + +# ── test configurations ─────────────────────────────────────────────────────── +# (n_rows, n_cols, topk, description) +# n_rows can be any positive integer – divisibility by BLOCK_DIM is NOT required. +# Configs sharing the same (n_cols, topk) reuse the same compiled .so. +_CONFIGS = [ + # 1 merge pass – topk < n_cols + (24, 128, 64, "n_rows=24, 1 pass, topk str: + return os.path.join(_DIR, f"{fn_name(n_cols, topk)}_lib.so") + + +def _compile(n_cols: int, topk: int) -> None: + subprocess.check_call( + ["bash", os.path.join(_DIR, "compile.sh"), str(n_cols), str(topk)], + cwd=_DIR, + ) + + +def _load_fn(n_cols: int, topk: int): + lib = ctypes.CDLL(_lib_path(n_cols, topk)) + fn = getattr(lib, f"call_{fn_name(n_cols, topk)}") + fn.argtypes = [ + ctypes.c_void_p, # stream + ctypes.c_void_p, # src [n_rows, n_cols] float32 + ctypes.c_void_p, # inIdx [n_cols] uint32 + ctypes.c_void_p, # out_scores [n_rows, topk] float32 + ctypes.c_void_p, # out_indices [n_rows, topk] uint32 + ctypes.c_int32, # n_rows (runtime) + ] + fn.restype = None + return fn + + +def _ptr(t: torch.Tensor) -> ctypes.c_void_p: + return ctypes.c_void_p(t.data_ptr()) + + +def _run_one(device: str, n_rows: int, n_cols: int, topk: int, desc: str) -> None: + fn = _load_fn(n_cols, topk) + torch.manual_seed(0) + + src = torch.rand(n_rows, n_cols, dtype=torch.float32, device=device) + inidx 
= torch.arange(n_cols, dtype=torch.int32, device=device) + out_scores = torch.empty(n_rows, topk, dtype=torch.float32, device=device) + out_indices = torch.empty(n_rows, topk, dtype=torch.int32, device=device) + + stream_ptr = torch.npu.current_stream()._as_parameter_ + torch.npu.synchronize() + fn( + stream_ptr, + _ptr(src), + _ptr(inidx), + _ptr(out_scores), + _ptr(out_indices), + ctypes.c_int32(n_rows), + ) + torch.npu.synchronize() + + src_cpu = src.cpu() + + # 1. Scores must exactly match torch.topk (descending, sorted). + ref_vals, _ = torch.topk(src_cpu, topk, dim=-1, largest=True, sorted=True) + torch.testing.assert_close( + out_scores.cpu(), + ref_vals, + rtol=0, + atol=0, + msg=f"scores mismatch ({desc})", + ) + + # 2. Each returned index must point to the correct value in the source row. + # (Don't compare indices directly – hardware may break ties differently.) + gathered = torch.gather(src_cpu, 1, out_indices.cpu().to(torch.int64)) + torch.testing.assert_close( + gathered, + out_scores.cpu(), + rtol=0, + atol=0, + msg=f"index↔score mismatch ({desc})", + ) + + print(f" PASSED {n_rows:5d}×{n_cols:5d} → top-{topk:5d} [{desc}]") + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--no-compile", + action="store_true", + help="skip recompilation (assume .so files already exist)", + ) + args = parser.parse_args() + + device = get_test_device() + torch.npu.set_device(device) + + print(f"Running {len(_CONFIGS)} TopK configs on {device}") + print("-" * 70) + + compiled: set = set() + for n_rows, n_cols, topk, desc in _CONFIGS: + if not args.no_compile and (n_cols, topk) not in compiled: + _compile(n_cols, topk) + compiled.add((n_cols, topk)) + _run_one(device, n_rows, n_cols, topk, desc) + + print("-" * 70) + print(f"All {len(_CONFIGS)} configs PASSED.") + + +if __name__ == "__main__": + main() diff --git a/examples/aot/topk/topk_builder.py b/examples/aot/topk/topk_builder.py new file mode 100644 index 00000000..18de316a --- 
/dev/null +++ b/examples/aot/topk/topk_builder.py @@ -0,0 +1,288 @@ +""" +TopK AOT kernel: for each row of an [N_ROWS × N_COLS] float32 matrix, find the +top-TOPK elements and return their values and original column indices. + +Pipeline (per row) +------------------ + 1. TSORT32 – sort within SORT_BLOCK_LEN-element blocks, writing interleaved + (score_f32, idx_u32) pairs to the sort buffer. + 2. TMRGSORT – multi-pass 4-way merge until the sort buffer is fully sorted + descending by score. Unrolled at builder time (static sizes). + 3. TMOV tb_sort → tb_gather_win (valid_shape=[1, 2*TOPK]). + 4. TGATHER P0101 on tb_gather_win – extract top-TOPK scores (even slots). + 5. TGATHER P1010 on tb_gather_win – extract top-TOPK indices (odd slots, + stored as uint32 bit-patterns in a float32 tile). + 6. TSTORE – write scores and indices to global memory. + +The gather-window tile has the same physical shape as the sort buffer but its +valid_shape is limited to [1, 2*TOPK]. This ensures TGATHER P0101/P1010 sees +exactly 2*TOPK elements and produces exactly TOPK outputs, even when TOPK < N_COLS +(without the window, P0101 on a sort_cols-element tile would produce N_COLS +outputs and overflow the TOPK-element destination tile). + +Constraints (verified by assertions in build_topk) +--------------------------------------------------- + * TOPK must be ≤ N_COLS. + * N_ROWS is unconstrained at compile time (any value works at runtime). + * HW_BLOCK_LEN (= SORT_BLOCK_LEN × DST_STRIDE) must be a multiple of 64. + * SORT_COLS (= N_COLS × DST_STRIDE) must be an exact power-of-4 multiple of + HW_BLOCK_LEN (guarantees a clean merge with no tail block). 
+ +Valid N_COLS values (SORT_BLOCK_LEN=32) +--------------------------------------- + SORT_COLS = N_COLS*2 must be a power-of-4 multiple of HW_BLOCK_LEN=64: + N_COLS = 128 → SORT_COLS = 256 (1 merge pass) + N_COLS = 512 → SORT_COLS = 1024 (2 merge passes) + N_COLS = 2048 → SORT_COLS = 4096 (3 merge passes) + +Usage +----- + python topk_builder.py # default: n_cols=512 topk=256 + python topk_builder.py --n-cols 128 --topk 64 +""" + +import argparse + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +# float32: TSORT32 expands each input element to (score_f32, idx_u32) = 2 words. +_DST_STRIDE = 2 +_SORT_BLOCK_LEN = 32 # TSORT32 sorts within blocks of this many input elements + + +def fn_name(n_cols: int, topk: int) -> str: + """Unique kernel name (n_rows is dynamic and not encoded).""" + return f"topk_c{n_cols}_k{topk}" + + +def build_topk( + n_cols: int = 512, + topk: int = 256, + block_dim: int = 24, + sort_block_len: int = _SORT_BLOCK_LEN, +): + """Return a compiled MLIR module for the given compile-time TopK shape. + + n_rows is NOT a compile-time parameter – it is a runtime ``int32`` argument + (``argN``) passed at each invocation. The kernel uses ``s.ceil_div`` to + distribute rows across blocks and guards the last block with ``if_context``, + so any n_rows value is supported without recompilation. + """ + sort_cols = n_cols * _DST_STRIDE + hw_block_len = sort_block_len * _DST_STRIDE + + assert topk <= n_cols, f"topk={topk} must be ≤ n_cols={n_cols}" + assert ( + hw_block_len % 64 == 0 + ), f"hw_block_len={hw_block_len} must be a multiple of 64" + _blk = hw_block_len + while _blk * 4 <= sort_cols: + _blk *= 4 + assert _blk == sort_cols, ( + f"sort_cols={sort_cols} is not a power-of-4 multiple of hw_block_len={hw_block_len}; " + "tail merging is not implemented in this example." 
+ ) + + def _meta_data(): + f32 = pto.float32 + u32 = pto.uint32 + tile_cfg = pto.TileBufConfig() + return { + "ptr_f32": pto.PtrType(f32), + "ptr_u32": pto.PtrType(u32), + "index_dtype": pto.int32, + "tensor_src": pto.TensorType(rank=2, dtype=f32), + "tensor_inidx": pto.TensorType(rank=2, dtype=u32), + "tensor_scores": pto.TensorType(rank=2, dtype=f32), + "tensor_indices": pto.TensorType(rank=2, dtype=u32), + "sub_src": pto.SubTensorType(shape=[1, n_cols], dtype=f32), + "sub_inidx": pto.SubTensorType(shape=[1, n_cols], dtype=u32), + "sub_scores": pto.SubTensorType(shape=[1, topk], dtype=f32), + "sub_indices": pto.SubTensorType(shape=[1, topk], dtype=u32), + "tile_src": pto.TileBufType( + shape=[1, n_cols], + valid_shape=[1, n_cols], + dtype=f32, + memory_space="VEC", + config=tile_cfg, + ), + "tile_inidx": pto.TileBufType( + shape=[1, n_cols], + valid_shape=[1, n_cols], + dtype=u32, + memory_space="VEC", + config=tile_cfg, + ), + "tile_sort_f32": pto.TileBufType( + shape=[1, sort_cols], + valid_shape=[1, sort_cols], + dtype=f32, + memory_space="VEC", + config=tile_cfg, + ), + "tile_sort_u32": pto.TileBufType( + shape=[1, sort_cols], + valid_shape=[1, sort_cols], + dtype=u32, + memory_space="VEC", + config=tile_cfg, + ), + # Gather window: same physical shape as tile_sort, but valid_shape + # limited to [1, 2*topk] so TGATHER P0101/P1010 produces topk outputs. 
+ "tile_gather_win_f32": pto.TileBufType( + shape=[1, sort_cols], + valid_shape=[1, 2 * topk], + dtype=f32, + memory_space="VEC", + config=tile_cfg, + ), + "tile_gather_win_u32": pto.TileBufType( + shape=[1, sort_cols], + valid_shape=[1, 2 * topk], + dtype=u32, + memory_space="VEC", + config=tile_cfg, + ), + "tile_topk_f32": pto.TileBufType( + shape=[1, topk], + valid_shape=[1, topk], + dtype=f32, + memory_space="VEC", + config=tile_cfg, + ), + "tile_topk_u32": pto.TileBufType( + shape=[1, topk], + valid_shape=[1, topk], + dtype=u32, + memory_space="VEC", + config=tile_cfg, + ), + } + + def _kernel( + src_ptr: "ptr_f32", # [n_rows, n_cols] float32 – input scores + inidx_ptr: "ptr_u32", # [n_cols] uint32 – original column indices + scores_ptr: "ptr_f32", # [n_rows, topk] float32 – output top-k scores + indices_ptr: "ptr_u32", # [n_rows, topk] uint32 – output top-k indices + argN: "index_dtype", # n_rows (runtime) + ) -> None: + c0 = const(0) + c1 = const(1) + c_ncols = const(n_cols) + c_topk = const(topk) + c_bdim = const(block_dim) + + n_rows_dyn = s.index_cast(argN) + bid = s.index_cast(pto.get_block_idx()) + + # Distribute rows across blocks with ceil_div – works for any n_rows. 
+ rows_per_core = s.ceil_div(n_rows_dyn, c_bdim) + row_start = bid * rows_per_core + row_end_raw = row_start + rows_per_core + need_clamp = row_end_raw > n_rows_dyn + rows_this_core = s.select(need_clamp, n_rows_dyn - row_start, rows_per_core) + + with pto.vector_section(): + tv_src = pto.as_tensor( + tensor_src, + ptr=src_ptr, + shape=[n_rows_dyn, c_ncols], + strides=[c_ncols, c1], + ) + tv_inidx = pto.as_tensor( + tensor_inidx, ptr=inidx_ptr, shape=[c1, c_ncols], strides=[c_ncols, c1] + ) + tv_scores = pto.as_tensor( + tensor_scores, + ptr=scores_ptr, + shape=[n_rows_dyn, c_topk], + strides=[c_topk, c1], + ) + tv_indices = pto.as_tensor( + tensor_indices, + ptr=indices_ptr, + shape=[n_rows_dyn, c_topk], + strides=[c_topk, c1], + ) + + tb_src = pto.alloc_tile(tile_src) + tb_inidx = pto.alloc_tile(tile_inidx) + tb_sort = pto.alloc_tile(tile_sort_f32) + tb_sort_tmp = pto.alloc_tile(tile_sort_f32) + tb_gather_win_f = pto.alloc_tile(tile_gather_win_f32) + tb_gather_win_u = pto.alloc_tile(tile_gather_win_u32) + tb_scores = pto.alloc_tile(tile_topk_f32) + tb_indices = pto.alloc_tile(tile_topk_u32) + + # Load shared column-index vector once per core. + sv_inidx = pto.slice_view( + sub_inidx, source=tv_inidx, offsets=[c0, c0], sizes=[c1, c_ncols] + ) + pto.load(sv_inidx, tb_inidx) + + # Guard: blocks beyond n_rows do nothing. + with pto.if_context(row_start < n_rows_dyn): + with pto.if_context(rows_this_core > c0): + for i in pto.range(c0, rows_this_core, c1): + row = i + row_start + + # 1. Load input row. + sv_src = pto.slice_view( + sub_src, + source=tv_src, + offsets=[row, c0], + sizes=[c1, c_ncols], + ) + pto.load(sv_src, tb_src) + + # 2. TSORT32: sort within sort_block_len-element blocks. + tile.sort32(tb_src, tb_sort, tb_inidx) + + # 3. Multi-pass TMRGSORT (unrolled at build time). + cur_block = hw_block_len + while cur_block * 4 <= sort_cols: + tile.mrgsort(tb_sort, tb_sort_tmp, const(cur_block)) + tile.mov(tb_sort_tmp, tb_sort) + cur_block *= 4 + + # 4. 
Copy into gather window (valid_shape=[1, 2*topk]). + tile.mov(tb_sort, tb_gather_win_f) + + # 5. Extract top-topk scores (even slots = score_f32). + tile.gather(tb_gather_win_f, tb_scores, mask_pattern="P0101") + + # 6. Extract top-topk indices (odd slots = idx_u32 bits). + tile.mov(tb_sort, tb_gather_win_u) + tile.gather(tb_gather_win_u, tb_indices, mask_pattern="P1010") + + # 7. Store outputs. + sv_scores = pto.slice_view( + sub_scores, + source=tv_scores, + offsets=[row, c0], + sizes=[c1, c_topk], + ) + pto.store(tb_scores, sv_scores) + + sv_indices = pto.slice_view( + sub_indices, + source=tv_indices, + offsets=[row, c0], + sizes=[c1, c_topk], + ) + pto.store(tb_indices, sv_indices) + + _kernel.__name__ = fn_name(n_cols, topk) + return to_ir_module(meta_data=_meta_data)(_kernel) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Print MLIR IR for a TopK kernel") + parser.add_argument("--n-cols", type=int, default=512) + parser.add_argument("--topk", type=int, default=256) + parser.add_argument("--block-dim", type=int, default=24) + args = parser.parse_args() + print(build_topk(n_cols=args.n_cols, topk=args.topk, block_dim=args.block_dim)) diff --git a/examples/jit/add_dynamic_multicore/run_add.py b/examples/jit/add_dynamic_multicore/run_add.py index aca0d9a1..d9ece92b 100644 --- a/examples/jit/add_dynamic_multicore/run_add.py +++ b/examples/jit/add_dynamic_multicore/run_add.py @@ -1,10 +1,10 @@ -from ptodsl import jit -import ptodsl.language as pto +from ptodsl import jit, pto, tile +from ptodsl import scalar as s import torch import torch_npu from ptodsl.test_util import get_test_device -const = pto.const +const = s.const def meta_data(): @@ -51,12 +51,12 @@ def vec_add_1d_dynamic( num_blocks = pto.get_block_num() # Convert i64/i32 values to index for arithmetic ops. 
- vid_idx = pto.index_cast(vid) - num_cores = pto.index_cast(num_blocks) - total_elements = pto.index_cast(argN) + vid_idx = s.index_cast(vid) + num_cores = s.index_cast(num_blocks) + total_elements = s.index_cast(argN) - num_tiles_global = pto.ceil_div(total_elements, c_tile) - num_tiles_per_core = pto.ceil_div(num_tiles_global, num_cores) + num_tiles_global = s.ceil_div(total_elements, c_tile) + num_tiles_per_core = s.ceil_div(num_tiles_global, num_cores) tile_offset_this_core = vid_idx * num_tiles_per_core with pto.vector_section(): @@ -74,29 +74,38 @@ def vec_add_1d_dynamic( need_truncate = tiles_end_this_core > num_tiles_global remaining_tiles = num_tiles_global - tile_offset_this_core - tiles_to_process = pto.select( + tiles_to_process = s.select( need_truncate, remaining_tiles, num_tiles_per_core ) elements_to_process = tiles_to_process * c_tile with pto.if_context(elements_to_process > c0): - for i in pto.for_range(c0, tiles_to_process, c1): + for i in pto.range(c0, tiles_to_process, c1): tile_offset_global = i + tile_offset_this_core offset_global = tile_offset_global * c_tile sv0 = pto.slice_view( - subtensor_type, source=tv0, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv0, + offsets=[offset_global], + sizes=[c_tile], ) sv1 = pto.slice_view( - subtensor_type, source=tv1, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv1, + offsets=[offset_global], + sizes=[c_tile], ) sv2 = pto.slice_view( - subtensor_type, source=tv2, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv2, + offsets=[offset_global], + sizes=[c_tile], ) pto.load(sv0, tb0) pto.load(sv1, tb1) - pto.add(tb0, tb1, tb2) + tile.add(tb0, tb1, tb2) pto.store(tb2, sv2) @@ -109,7 +118,14 @@ def test_add(): tile_size = 1024 # Keep shapes aligned to tile size, but vary tile counts so they are not # required to be multiples of `num_cores`. 
- tile_counts = [1, 7, num_cores - 1, num_cores + 3, 2 * num_cores + 7, 5 * num_cores - 5] + tile_counts = [ + 1, + 7, + num_cores - 1, + num_cores + 3, + 2 * num_cores + 7, + 5 * num_cores - 5, + ] shape_list = [tile_size * tiles for tiles in tile_counts] torch.manual_seed(0) diff --git a/examples/jit/add_static_multicore/run_add_1d.py b/examples/jit/add_static_multicore/run_add_1d.py index 722fe24c..b47a0c15 100644 --- a/examples/jit/add_static_multicore/run_add_1d.py +++ b/examples/jit/add_static_multicore/run_add_1d.py @@ -1,10 +1,10 @@ -from ptodsl import jit -import ptodsl.language as pto +from ptodsl import jit, pto, tile +from ptodsl import scalar as s import torch import torch_npu from ptodsl.test_util import get_test_device -const = pto.const +const = s.const def meta_data(): @@ -13,9 +13,14 @@ def meta_data(): ptr_type = pto.PtrType(dtype) tensor_type = pto.TensorType(rank=2, dtype=dtype) - subtensor_type = pto.SubTensorType(shape=[1, 1024], dtype=dtype) + subtensor_type = pto.SubTensorType(shape=[1, 1024], dtype=dtype) tile_type = pto.TileBufType( - shape=[1, 1024], valid_shape=[-1, -1], dtype=dtype, memory_space="VEC", config=pto.TileBufConfig()) + shape=[1, 1024], + valid_shape=[-1, -1], + dtype=dtype, + memory_space="VEC", + config=pto.TileBufConfig(), + ) return { "ptr_type": ptr_type, @@ -26,11 +31,7 @@ def meta_data(): @jit(meta_data=meta_data, block_dim=20) -def vec_add_kernel( - arg0: "ptr_type", - arg1: "ptr_type", - arg2: "ptr_type" - ) -> None: +def vec_add_kernel(arg0: "ptr_type", arg1: "ptr_type", arg2: "ptr_type") -> None: c0 = const(0) c1 = const(1) c1024 = const(1024) @@ -46,11 +47,17 @@ def vec_add_kernel( tv1 = pto.as_tensor(tensor_type, ptr=arg1, shape=[c1, c1024], strides=[c1024, c1]) tv2 = pto.as_tensor(tensor_type, ptr=arg2, shape=[c1, c1024], strides=[c1024, c1]) - vid_idx = pto.index_cast(vid) + vid_idx = s.index_cast(vid) offset = vid_idx * c1024 # every core loads 1024 elements of data - sv0 = pto.slice_view(subtensor_type, 
source=tv0, offsets=[c0, offset], sizes=[c1, c1024]) - sv1 = pto.slice_view(subtensor_type, source=tv1, offsets=[c0, offset], sizes=[c1, c1024]) - sv2 = pto.slice_view(subtensor_type, source=tv2, offsets=[c0, offset], sizes=[c1, c1024]) + sv0 = pto.slice_view( + subtensor_type, source=tv0, offsets=[c0, offset], sizes=[c1, c1024] + ) + sv1 = pto.slice_view( + subtensor_type, source=tv1, offsets=[c0, offset], sizes=[c1, c1024] + ) + sv2 = pto.slice_view( + subtensor_type, source=tv2, offsets=[c0, offset], sizes=[c1, c1024] + ) with pto.vector_section(): tb0 = pto.alloc_tile(tile_type, valid_row=c1, valid_col=c1024) @@ -59,7 +66,7 @@ def vec_add_kernel( pto.load(sv0, tb0) pto.load(sv1, tb1) - pto.add(tb0, tb1, tb2) + tile.add(tb0, tb1, tb2) pto.store(tb2, sv2) @@ -81,5 +88,6 @@ def test_add(): torch.testing.assert_close(z, z_ref) print("result equal!") + if __name__ == "__main__": test_add() diff --git a/examples/jit/add_static_multicore/run_add_2d.py b/examples/jit/add_static_multicore/run_add_2d.py index ca73fe23..663c7a91 100644 --- a/examples/jit/add_static_multicore/run_add_2d.py +++ b/examples/jit/add_static_multicore/run_add_2d.py @@ -1,10 +1,10 @@ -from ptodsl import jit -import ptodsl.language as pto +from ptodsl import jit, pto, tile +from ptodsl import scalar as s import torch import torch_npu from ptodsl.test_util import get_test_device -const = pto.const +const = s.const def meta_data(): @@ -13,11 +13,18 @@ def meta_data(): index_dtype = pto.int32 ptr_type = pto.PtrType(dtype) tensor_type = pto.TensorType(rank=2, dtype=dtype) - subtensor_type = pto.SubTensorType(shape=[32, 32], dtype=dtype) # TODO: omit shape https://github.com/zhangstevenunity/PTOAS/issues/31 + subtensor_type = pto.SubTensorType( + shape=[32, 32], dtype=dtype + ) # TODO: omit shape https://github.com/zhangstevenunity/PTOAS/issues/31 tile_cfg = pto.TileBufConfig() # defaults to pto.TileBufConfig(blayout="RowMajor", slayout="NoneBox", s_fractal_size=512, pad="Null") tile_type = 
pto.TileBufType( - shape=[32, 32], valid_shape=[-1, -1], dtype=dtype, memory_space="VEC", config=tile_cfg) + shape=[32, 32], + valid_shape=[-1, -1], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) return { "ptr_type": ptr_type, "index_dtype": index_dtype, @@ -33,8 +40,8 @@ def vec_add_kernel( arg1: "ptr_type", arg2: "ptr_type", vrow: "index_dtype", - vcol: "index_dtype" - ) -> None: + vcol: "index_dtype", +) -> None: c0 = const(0) c1 = const(1) c32 = const(32) @@ -46,18 +53,24 @@ def vec_add_kernel( cidmul = cid * sub_bnum vid = cidmul + sub_bid - v_row_idx = pto.index_cast(vrow) - v_col_idx = pto.index_cast(vcol) + v_row_idx = s.index_cast(vrow) + v_col_idx = s.index_cast(vcol) tv0 = pto.as_tensor(tensor_type, ptr=arg0, shape=[c1280, c32], strides=[c32, c1]) tv1 = pto.as_tensor(tensor_type, ptr=arg1, shape=[c1280, c32], strides=[c32, c1]) tv2 = pto.as_tensor(tensor_type, ptr=arg2, shape=[c1280, c32], strides=[c32, c1]) - vid_idx = pto.index_cast(vid) + vid_idx = s.index_cast(vid) offset_row = vid_idx * c32 # every core loads 32 rows of data - sv0 = pto.slice_view(subtensor_type, source=tv0, offsets=[offset_row, c0], sizes=[c32, c32]) - sv1 = pto.slice_view(subtensor_type, source=tv1, offsets=[offset_row, c0], sizes=[c32, c32]) - sv2 = pto.slice_view(subtensor_type, source=tv2, offsets=[offset_row, c0], sizes=[c32, c32]) + sv0 = pto.slice_view( + subtensor_type, source=tv0, offsets=[offset_row, c0], sizes=[c32, c32] + ) + sv1 = pto.slice_view( + subtensor_type, source=tv1, offsets=[offset_row, c0], sizes=[c32, c32] + ) + sv2 = pto.slice_view( + subtensor_type, source=tv2, offsets=[offset_row, c0], sizes=[c32, c32] + ) with pto.vector_section(): tb0 = pto.alloc_tile(tile_type, valid_row=v_row_idx, valid_col=v_col_idx) @@ -66,7 +79,7 @@ def vec_add_kernel( pto.load(sv0, tb0) pto.load(sv1, tb1) - pto.add(tb0, tb1, tb2) + tile.add(tb0, tb1, tb2) pto.store(tb2, sv2) @@ -88,5 +101,6 @@ def test_add(): torch.testing.assert_close(z, z_ref) print("result 
equal!") + if __name__ == "__main__": test_add() diff --git a/examples/jit/matmul_dynamic_multicore/run_batch_matmul.py b/examples/jit/matmul_dynamic_multicore/run_batch_matmul.py index 1f75cbc3..0f668aa9 100644 --- a/examples/jit/matmul_dynamic_multicore/run_batch_matmul.py +++ b/examples/jit/matmul_dynamic_multicore/run_batch_matmul.py @@ -1,12 +1,12 @@ from mlir.ir import IntegerType -from ptodsl import jit -import ptodsl.language as pto +from ptodsl import jit, pto, tile +from ptodsl import scalar as s import torch import torch_npu from ptodsl.test_util import get_test_device -const = pto.const +const = s.const def build_kernel( @@ -34,14 +34,26 @@ def meta_data(): tile_view_out = pto.SubTensorType(shape=[M, N], dtype=dtype) tile_view_bias = pto.SubTensorType(shape=[1, N], dtype=dtype) - tile_buf_aMat = pto.TileBufType(shape=[M, BASEK], dtype=dtype, memory_space="MAT") - tile_buf_bMat = pto.TileBufType(shape=[BASEK, N], dtype=dtype, memory_space="MAT") - tile_buf_biasData = pto.TileBufType(shape=[1, N], dtype=dtype, memory_space="MAT") - - tile_buf_aTile = pto.TileBufType(shape=[M, BASEK], dtype=dtype, memory_space="LEFT") - tile_buf_bTile = pto.TileBufType(shape=[BASEK, N], dtype=dtype, memory_space="RIGHT") + tile_buf_aMat = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="MAT" + ) + tile_buf_bMat = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="MAT" + ) + tile_buf_biasData = pto.TileBufType( + shape=[1, N], dtype=dtype, memory_space="MAT" + ) + + tile_buf_aTile = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="LEFT" + ) + tile_buf_bTile = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="RIGHT" + ) tile_buf_cTile = pto.TileBufType(shape=[M, N], dtype=dtype, memory_space="ACC") - tile_buf_biasTile = pto.TileBufType(shape=[1, N], dtype=dtype, memory_space="BIAS") + tile_buf_biasTile = pto.TileBufType( + shape=[1, N], dtype=dtype, memory_space="BIAS" + ) return { "ptr_type": ptr_dtype, @@ -81,20 
+93,28 @@ def RunTMATMULSplitK( cTileM = const(M) cTileN = const(N) - batch = pto.index_cast(batch_i32) + batch = s.index_cast(batch_i32) cBM = batch * cM - num_blocks = pto.index_cast(pto.get_block_num()) - batches_per_core = pto.ceil_div(batch, num_blocks) - bid = pto.index_cast(pto.get_block_idx()) + num_blocks = s.index_cast(pto.get_block_num()) + batches_per_core = s.ceil_div(batch, num_blocks) + bid = s.index_cast(pto.get_block_idx()) b_start = bid * batches_per_core b_end_unclamped = b_start + batches_per_core - b_end = pto.min_u(b_end_unclamped, batch) - - tvA = pto.as_tensor(tensor_type, ptr=a_ptr, shape=[cBM, cK], strides=[cK, c1]) - tvB = pto.as_tensor(tensor_type, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1]) - tvOut = pto.as_tensor(tensor_type, ptr=out_ptr, shape=[cBM, cN], strides=[cN, c1]) - tvBias = pto.as_tensor(tensor_type, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1]) + b_end = s.min_u(b_end_unclamped, batch) + + tvA = pto.as_tensor( + tensor_type, ptr=a_ptr, shape=[cBM, cK], strides=[cK, c1] + ) + tvB = pto.as_tensor( + tensor_type, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1] + ) + tvOut = pto.as_tensor( + tensor_type, ptr=out_ptr, shape=[cBM, cN], strides=[cN, c1] + ) + tvBias = pto.as_tensor( + tensor_type, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1] + ) aMatTile = pto.alloc_tile(tile_buf_aMat) bMatTile = pto.alloc_tile(tile_buf_bMat) @@ -104,10 +124,10 @@ def RunTMATMULSplitK( cTile = pto.alloc_tile(tile_buf_cTile) biasTile = pto.alloc_tile(tile_buf_biasTile) - for b_idx in pto.for_range(b_start, b_end, c1): + for b_idx in pto.range(b_start, b_end, c1): row_off = b_idx * cM - for i in pto.for_range(c0, cIter, c1): + for i in pto.range(c0, cIter, c1): kOff = i * cBASEK svA = pto.slice_view( tile_view_a, @@ -135,26 +155,26 @@ def RunTMATMULSplitK( pto.record_wait_pair("LOAD", "MOV_M2L", event_id=0) - pto.mov(aMatTile, aTile) - pto.mov(bMatTile, bTile) + tile.mov(aMatTile, aTile) + tile.mov(bMatTile, bTile) with pto.if_context(isBias): - 
pto.mov(biasDataTile, biasTile) + tile.mov(biasDataTile, biasTile) pto.record_wait_pair("MOV_M2L", "MATMUL", event_id=0) - is_i0 = pto.eq(i, c0) + is_i0 = s.eq(i, c0) def _first_iter(): pto.cond( isBias, - lambda: pto.matmul_bias(aTile, bTile, biasTile, cTile), - lambda: pto.matmul(aTile, bTile, cTile), + lambda: tile.matmul_bias(aTile, bTile, biasTile, cTile), + lambda: tile.matmul(aTile, bTile, cTile), ) pto.cond( is_i0, _first_iter, - lambda: pto.matmul_acc(cTile, aTile, bTile, cTile), + lambda: tile.matmul_acc(cTile, aTile, bTile, cTile), ) pto.record_wait_pair("MATMUL", "LOAD", event_id=0) diff --git a/examples/ppt/mixed_pto_vector_slide.md b/examples/ppt/mixed_pto_vector_slide.md new file mode 100644 index 00000000..803180dd --- /dev/null +++ b/examples/ppt/mixed_pto_vector_slide.md @@ -0,0 +1,77 @@ +# PTO `t*` + `v*` 混合示例 + +## 一页版表达 + +```text +Outer PTO tile flow: + make_tensor_view -> partition_view -> tload -> [vector inner loop] -> tstore + +Inner vector loop: + vlds -> vlds -> vadd -> vsts +``` + +## PPT 版伪 IR + +```mlir +module { + func.func @vec_add_mixed( + %a: !pto.ptr, + %b: !pto.ptr, + %c: !pto.ptr) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %c64 = arith.constant 64 : index + %c1024 = arith.constant 1024 : index + + // 1) 先用 PTO tile op 选出一个 32x32 工作块 + %A = pto.make_tensor_view %a, shape = [%c32, %c32], strides = [%c32, %c1] + : !pto.tensor_view + %B = pto.make_tensor_view %b, shape = [%c32, %c32], strides = [%c32, %c1] + : !pto.tensor_view + %C = pto.make_tensor_view %c, shape = [%c32, %c32], strides = [%c32, %c1] + : !pto.tensor_view + + %tileA = pto.partition_view %A, offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %tileB = pto.partition_view %B, offsets = [%c0, %c0], sizes = [%c32, %c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %tileC = pto.partition_view %C, offsets = [%c0, %c0], sizes = [%c32, 
%c32] + : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + + // 统一记号:!tile 表示 vec-local 32x32 f32 tile_buf + %bufA = pto.alloc_tile : !pto.tile_buf + %bufB = pto.alloc_tile : !pto.tile_buf + %bufC = pto.alloc_tile : !pto.tile_buf + + // 2) tile 级搬运:GM -> local tile + pto.tload ins(%tileA : !pto.partition_tensor_view<32x32xf32>) + outs(%bufA : !pto.tile_buf) + pto.tload ins(%tileB : !pto.partition_tensor_view<32x32xf32>) + outs(%bufB : !pto.tile_buf) + + // 3) vector 级计算:在 local tile 内部按 64-lane 分块 + %ptrA = pto.tile_buf_addr %bufA : !pto.tile_buf<...> -> !llvm.ptr<6> + %ptrB = pto.tile_buf_addr %bufB : !pto.tile_buf<...> -> !llvm.ptr<6> + %ptrC = pto.tile_buf_addr %bufC : !pto.tile_buf<...> -> !llvm.ptr<6> + + scf.for %i = %c0 to %c1024 step %c64 { + %va = pto.vlds %ptrA[%i] : !llvm.ptr<6> -> !pto.vreg<64xf32> + %vb = pto.vlds %ptrB[%i] : !llvm.ptr<6> -> !pto.vreg<64xf32> + %vc = pto.vadd %va, %vb + : !pto.vreg<64xf32>, !pto.vreg<64xf32> -> !pto.vreg<64xf32> + pto.vsts %vc, %ptrC[%i] : !pto.vreg<64xf32>, !llvm.ptr<6> + } + + // 4) tile 级写回:local tile -> GM + pto.tstore ins(%bufC : !pto.tile_buf) + outs(%tileC : !pto.partition_tensor_view<32x32xf32>) + return + } +} +``` + +## 讲解时只强调这两层 + +- `pto.t*` 负责选 tile 和搬 tile:`make_tensor_view -> partition_view -> tload -> tstore` +- `pto.v*` 负责在 tile 内做向量计算:`vlds -> vadd -> vsts` diff --git a/examples/validate_all_examples.py b/examples/validate_all_examples.py index 04521c1f..bd80e4e1 100644 --- a/examples/validate_all_examples.py +++ b/examples/validate_all_examples.py @@ -71,7 +71,9 @@ def extract_commands(readme_path: Path) -> list[str]: return [] -def run_example(example_name: str, readme_path: Path, commands: list[str]) -> ExampleResult: +def run_example( + example_name: str, readme_path: Path, commands: list[str] +) -> ExampleResult: example_start = time.time() if not commands: return ExampleResult( diff --git a/ptodsl/__init__.py b/ptodsl/__init__.py index 7a5a7a3c..5ed02b28 100644 --- 
a/ptodsl/__init__.py +++ b/ptodsl/__init__.py @@ -1,388 +1,18 @@ -import inspect -import ctypes -import os -import pathlib -import subprocess -from functools import update_wrapper - -from mlir.dialects import func, pto -from mlir.ir import Context, InsertionPoint, Location, Module - -from .language import wrap_value +from . import pto, scalar, tile from .bench import do_bench - - -def _resolve_meta(meta_fn): - values = meta_fn() - if not isinstance(values, dict): - raise ValueError("`meta_data()` must return a dict of named symbols to MLIR/PTO types.") - return dict(values) - - -def _resolve_arg_types(signature, meta_map): - arg_types = [] - for param in signature.parameters.values(): - annot = param.annotation - if isinstance(annot, str): - if annot not in meta_map: - raise ValueError(f"Unknown annotation '{annot}'.") - arg_types.append(meta_map[annot]) - elif annot is inspect._empty: - raise ValueError(f"Missing annotation for argument '{param.name}'.") - else: - arg_types.append(annot) - return arg_types - - -def _resolve_ret_types(signature, meta_map): - ret_annot = signature.return_annotation - if ret_annot in (inspect._empty, None): - return [] - if isinstance(ret_annot, str): - if ret_annot not in meta_map: - raise ValueError(f"Unknown return annotation '{ret_annot}'.") - return [meta_map[ret_annot]] - if isinstance(ret_annot, (list, tuple)): - out = [] - for elem in ret_annot: - if isinstance(elem, str): - out.append(meta_map[elem]) - else: - out.append(elem) - return out - return [ret_annot] - - -def _has_func_return(block): - last_name = None - for op in block.operations: - last_name = op.operation.name - return last_name == "func.return" - - -def _inject_globals(fn, values): - old = {} - for name, value in values.items(): - old[name] = fn.__globals__.get(name, None) - fn.__globals__[name] = value - return old - - -def _restore_globals(fn, old, injected_names): - for name in injected_names: - if old[name] is None and name in fn.__globals__: - del 
fn.__globals__[name] - else: - fn.__globals__[name] = old[name] - - -def to_ir_module(*, meta_data): - def decorator(fn): - sig = inspect.signature(fn) - - with Context() as ctx, Location.unknown(): - pto.register_dialect(ctx, load=True) - meta_map = _resolve_meta(meta_data) - arg_types = _resolve_arg_types(sig, meta_map) - ret_types = _resolve_ret_types(sig, meta_map) - module = Module.create() - fn_ty = func.FunctionType.get(arg_types, ret_types) - - with InsertionPoint(module.body): - ir_func = func.FuncOp(fn.__name__, fn_ty) - entry = ir_func.add_entry_block() - - with InsertionPoint(entry): - wrapped_args = [wrap_value(arg) for arg in entry.arguments] - injected = set(meta_map.keys()) - old_globals = _inject_globals(fn, meta_map) - try: - fn(*wrapped_args) - finally: - _restore_globals(fn, old_globals, injected) - - if not ret_types and not _has_func_return(entry): - func.ReturnOp([]) - - module.operation.verify() - return module - - return decorator - - -def _type_repr(type_obj): - return str(type_obj).replace(" ", "").lower() - - -def _is_ptr_type(type_obj): - return "ptr" in _type_repr(type_obj) - - -def _ptr_elem_cpp_type(type_obj): - type_repr = _type_repr(type_obj) - if "f32" in type_repr: - return "float" - if "f16" in type_repr: - return "__fp16" - if "bf16" in type_repr: - return "__bf16" - if "i8" in type_repr: - return "int8_t" - if "u8" in type_repr: - return "uint8_t" - if "i16" in type_repr: - return "int16_t" - if "u16" in type_repr: - return "uint16_t" - if "i32" in type_repr: - return "int32_t" - if "u32" in type_repr: - return "uint32_t" - if "i64" in type_repr: - return "int64_t" - if "u64" in type_repr: - return "uint64_t" - return "float" - - -def _scalar_cpp_type(type_obj): - type_repr = _type_repr(type_obj) - if "i32" in type_repr: - return "int32_t" - if "i64" in type_repr or "index" in type_repr: - return "int64_t" - if "f32" in type_repr: - return "float" - if "f16" in type_repr: - return "__fp16" - return "int32_t" - - -def 
_scalar_ctype(type_obj): - type_repr = _type_repr(type_obj) - if "i64" in type_repr or "index" in type_repr: - return ctypes.c_int64 - if "f32" in type_repr: - return ctypes.c_float - if "f16" in type_repr: - return ctypes.c_uint16 - return ctypes.c_int32 - - -def _normalize_stream_ptr(stream_ptr): - if isinstance(stream_ptr, ctypes.c_void_p): - return stream_ptr - if isinstance(stream_ptr, int): - return ctypes.c_void_p(stream_ptr) - if hasattr(stream_ptr, "value"): - return ctypes.c_void_p(int(stream_ptr.value)) - return stream_ptr - - -class JitWrapper: - def __init__( - self, - fn, - *, - meta_data, - output_dir=None, - block_dim=20, - enable_insert_sync=True, - npu_arch="dav-2201", - ): - self._fn = fn - self._meta_data = meta_data - self._sig = inspect.signature(fn) - self._arg_types = None - self._output_dir = pathlib.Path(output_dir) if output_dir else pathlib.Path.cwd() / ".ptodsl_jit" / fn.__name__ - self._block_dim = block_dim - self._enable_insert_sync = enable_insert_sync - self._npu_arch = npu_arch - self._compiled = False - self._lib = None - self._lib_path = self._output_dir / "kernel.so" - update_wrapper(self, fn) - - def _artifact_paths(self): - pto_path = self._output_dir / "kernel.pto" - cpp_path = self._output_dir / "kernel.cpp" - caller_path = self._output_dir / "caller.cpp" - return pto_path, cpp_path, caller_path, self._lib_path - - def _generate_caller_cpp(self, kernel_cpp_name): - params = list(self._sig.parameters.values()) - cpp_args = [] - launch_args = [] - for param, arg_type in zip(params, self._arg_types): - if _is_ptr_type(arg_type): - cpp_args.append(f"uint8_t *{param.name}") - launch_args.append(f"({ _ptr_elem_cpp_type(arg_type) } *){param.name}") - else: - cpp_t = _scalar_cpp_type(arg_type) - cpp_args.append(f"{cpp_t} {param.name}") - launch_args.append(param.name) - - wrapper_sig = ", ".join(["uint32_t blockDim", "void *stream"] + cpp_args) - kernel_call = ", ".join(launch_args) - return ( - f'#include "{kernel_cpp_name}"\n' - 
f"#include \n\n" - f'extern "C" void call_kernel({wrapper_sig})\n' - "{\n" - f" {self._fn.__name__}<<>>({kernel_call});\n" - "}\n" - ) - - def _compile_shared_library(self, caller_cpp_path, lib_path): - toolkit_home = os.environ.get("ASCEND_TOOLKIT_HOME") - if not toolkit_home: - raise RuntimeError("ASCEND_TOOLKIT_HOME is required to compile generated caller.cpp.") - cmd = [ - "bisheng", - f"-I{toolkit_home}/include", - "-fPIC", - "-shared", - "-D_FORTIFY_SOURCE=2", - "-O2", - "-std=c++17", - "-Wno-macro-redefined", - "-Wno-ignored-attributes", - "-fstack-protector-strong", - "-xcce", - "-Xhost-start", - "-Xhost-end", - "-mllvm", - "-cce-aicore-stack-size=0x8000", - "-mllvm", - "-cce-aicore-function-stack-size=0x8000", - "-mllvm", - "-cce-aicore-record-overflow=true", - "-mllvm", - "-cce-aicore-addr-transform", - "-mllvm", - "-cce-aicore-dcci-insert-for-scalar=false", - f"--npu-arch={self._npu_arch}", - "-DMEMORY_BASE", # TODO: add switch for A5 - "-std=gnu++17", - str(caller_cpp_path), - "-o", - str(lib_path), - ] - subprocess.run(cmd, check=True, cwd=str(self._output_dir)) - - def _resolve_runtime_arg_types(self): - with Context() as ctx, Location.unknown(): - pto.register_dialect(ctx, load=True) - meta_map = _resolve_meta(self._meta_data) - return _resolve_arg_types(self._sig, meta_map) - - def _build(self): - self._output_dir.mkdir(parents=True, exist_ok=True) - pto_path, cpp_path, caller_path, lib_path = self._artifact_paths() - self._arg_types = self._resolve_runtime_arg_types() - - ir_module = to_ir_module(meta_data=self._meta_data)(self._fn) - pto_path.write_text(f"{ir_module}\n", encoding="utf-8") - - ptoas_cmd = ["ptoas"] - if self._enable_insert_sync: - ptoas_cmd.append("--enable-insert-sync") - ptoas_cmd += [str(pto_path), "-o", str(cpp_path)] - subprocess.run(ptoas_cmd, check=True, cwd=str(self._output_dir)) - - caller_path.write_text(self._generate_caller_cpp(cpp_path.name), encoding="utf-8") - self._compile_shared_library(caller_path, lib_path) - - 
self._lib = ctypes.CDLL(str(lib_path)) - self._lib.call_kernel.argtypes = [ctypes.c_uint32, ctypes.c_void_p] + [ - ctypes.c_void_p if _is_ptr_type(arg_type) else _scalar_ctype(arg_type) - for arg_type in self._arg_types - ] - self._compiled = True - - def _convert_ptr(self, value): - if isinstance(value, ctypes.c_void_p): - return value - if hasattr(value, "data_ptr"): - return ctypes.c_void_p(value.data_ptr()) - if isinstance(value, int): - return ctypes.c_void_p(value) - raise TypeError(f"Pointer-like argument expected, got {type(value)!r}.") - - def _prepare_call_args(self, args): - params = list(self._sig.parameters.values()) - if len(args) > len(params): - raise TypeError(f"Expected at most {len(params)} arguments, got {len(args)}.") - - filled_args = list(args) - for idx in range(len(args), len(params)): - param = params[idx] - if param.default is not inspect._empty: - filled_args.append(param.default) - continue - arg_type = self._arg_types[idx] - if _is_ptr_type(arg_type): - raise TypeError(f"Missing required pointer argument '{param.name}'.") - - converted = [] - for value, arg_type in zip(filled_args, self._arg_types): - if _is_ptr_type(arg_type): - converted.append(self._convert_ptr(value)) - else: - converted.append(value) - return converted - - # TODO: also allow taking named `kwargs` - def __call__(self, *args, stream_ptr=None): - if not self._compiled: - self._build() - - if stream_ptr is None: - import torch - stream_ptr = torch.npu.current_stream()._as_parameter_ - - call_args = self._prepare_call_args(args) - self._lib.call_kernel( - ctypes.c_uint32(self._block_dim), - _normalize_stream_ptr(stream_ptr), - *call_args, - ) - return None - - def set_block_dim(self, block_dim): - if not isinstance(block_dim, int) or block_dim <= 0: - raise ValueError("`block_dim` must be a positive integer.") - self._block_dim = block_dim - return self - - @property - def library_path(self): - return str(self._lib_path) - - @property - def output_dir(self): - return 
str(self._output_dir) - - -def jit( - *, - meta_data, - output_dir=None, - block_dim=1, - enable_insert_sync=True, - npu_arch="dav-2201", -): - def decorator(fn): - return JitWrapper( - fn, - meta_data=meta_data, - output_dir=output_dir, - block_dim=block_dim, - enable_insert_sync=enable_insert_sync, - npu_arch=npu_arch, - ) - - return decorator - - -__all__ = ["JitWrapper", "jit", "to_ir_module", "do_bench"] +from .compiler.ir import to_ir_module +from .compiler.jit import JitWrapper, jit +from .constexpr import Constexpr, const_expr, range_constexpr + +__all__ = [ + "Constexpr", + "JitWrapper", + "const_expr", + "do_bench", + "jit", + "pto", + "range_constexpr", + "scalar", + "tile", + "to_ir_module", +] diff --git a/ptodsl/api/__init__.py b/ptodsl/api/__init__.py new file mode 100644 index 00000000..ca7e01f9 --- /dev/null +++ b/ptodsl/api/__init__.py @@ -0,0 +1,3 @@ +from . import pto, scalar, tile + +__all__ = ["pto", "scalar", "tile"] diff --git a/ptodsl/api/control_flow.py b/ptodsl/api/control_flow.py new file mode 100644 index 00000000..457fade8 --- /dev/null +++ b/ptodsl/api/control_flow.py @@ -0,0 +1,52 @@ +from contextlib import contextmanager + +from mlir.dialects import scf +from mlir.ir import InsertionPoint + +from .scalar import Value, _unwrap + + +def range(start, stop, step): + loop = scf.ForOp(_unwrap(start), _unwrap(stop), _unwrap(step)) + with InsertionPoint(loop.body): + yield Value(loop.induction_variable) + scf.YieldOp([]) + + +class _IfElseBranch: + def __init__(self, if_op): + self._if_op = if_op + + @contextmanager + def else_context(self): + with InsertionPoint(self._if_op.else_block): + yield + scf.YieldOp([]) + + +@contextmanager +def if_context(condition, has_else=False): + if has_else: + op = scf.IfOp(_unwrap(condition), [], hasElse=True) + branch = _IfElseBranch(op) + else: + op = scf.IfOp(_unwrap(condition)) + branch = None + + with InsertionPoint(op.then_block): + yield branch + scf.YieldOp([]) + + +def cond(condition, 
then_builder, else_builder): + op = scf.IfOp(_unwrap(condition), [], hasElse=True) + with InsertionPoint(op.then_block): + then_builder() + scf.YieldOp([]) + with InsertionPoint(op.else_block): + else_builder() + scf.YieldOp([]) + return op + + +__all__ = ["cond", "range", "if_context"] diff --git a/ptodsl/api/pto.py b/ptodsl/api/pto.py new file mode 100644 index 00000000..f2e2d0ac --- /dev/null +++ b/ptodsl/api/pto.py @@ -0,0 +1,60 @@ +from .control_flow import cond, range, if_context +from .scalar import Value, wrap_value +from .pto_general import ( + alloc_tile, + as_tensor, + cube_section, + get_block_idx, + get_block_num, + get_subblock_idx, + get_subblock_num, + load, + slice_view, + store, + vector_section, + print, +) +from .synchronization import barrier, record_event, record_wait_pair, wait_event +from .type_def import ( + PtrType, + SubTensorType, + TensorType, + TileBufConfig, + TileBufType, + __getattr__, +) + + +__all__ = [ + "Value", + "wrap_value", + "bool", + "float16", + "float32", + "int16", + "int32", + "PtrType", + "TensorType", + "SubTensorType", + "TileBufConfig", + "TileBufType", + "get_block_idx", + "get_subblock_idx", + "get_subblock_num", + "get_block_num", + "as_tensor", + "slice_view", + "vector_section", + "cube_section", + "range", + "if_context", + "cond", + "alloc_tile", + "load", + "store", + "print", + "record_event", + "wait_event", + "record_wait_pair", + "barrier", +] diff --git a/ptodsl/api/pto_general.py b/ptodsl/api/pto_general.py new file mode 100644 index 00000000..c8f649ea --- /dev/null +++ b/ptodsl/api/pto_general.py @@ -0,0 +1,117 @@ +from contextlib import contextmanager + +from mlir.dialects import pto as _pto +from mlir.ir import InsertionPoint + +from .scalar import Value, _unwrap + + +def get_block_idx(): + return Value(_pto.GetBlockIdxOp().result) + + +def get_subblock_idx(): + return Value(_pto.GetSubBlockIdxOp().result) + + +def get_subblock_num(): + return Value(_pto.GetSubBlockNumOp().result) + + +def 
get_block_num(): + return Value(_pto.GetBlockNumOp().result) + + +def _resolve_layout_attr(layout): + if layout is None: + return None + if isinstance(layout, str): + return _pto.LayoutAttr.get(getattr(_pto.Layout, layout)) + return layout + + +def as_tensor(tensor_type, *, ptr, shape, strides, layout=None): + shape_vals = [_unwrap(v) for v in shape] + stride_vals = [_unwrap(v) for v in strides] + kwargs = {} + layout_attr = _resolve_layout_attr(layout) + if layout_attr is not None: + kwargs["layout"] = layout_attr + return _pto.MakeTensorViewOp( + tensor_type, _unwrap(ptr), shape_vals, stride_vals, **kwargs + ).result + + +def slice_view(subtensor_type, *, source, offsets, sizes): + offset_vals = [_unwrap(v) for v in offsets] + size_vals = [_unwrap(v) for v in sizes] + return _pto.PartitionViewOp( + subtensor_type, source, offsets=offset_vals, sizes=size_vals + ).result + + +@contextmanager +def vector_section(): + section = _pto.SectionVectorOp() + block = section.body.blocks.append() + with InsertionPoint(block): + yield + + +@contextmanager +def cube_section(): + section = _pto.SectionCubeOp() + block = section.body.blocks.append() + with InsertionPoint(block): + yield + + +def alloc_tile(tile_type, *, addr=None, valid_row=None, valid_col=None): + kwargs = {} + if addr is not None: + kwargs["addr"] = _unwrap(addr) + if valid_row is not None: + kwargs["valid_row"] = _unwrap(valid_row) + if valid_col is not None: + kwargs["valid_col"] = _unwrap(valid_col) + return _pto.AllocTileOp(tile_type, **kwargs).result + + +def load(source, dest): + _pto.TLoadOp(None, source, dest) + + +def store(source, dest): + _pto.TStoreOp(None, source, dest) + + +def print(format, scalar): + """ + Example: + `print("hello %d\n", const(5))` + is equivalent to + `cce::printf("hello%d\n", 5);` + + NOTE: may not print if the print buffer is full from previous + prints (typical when printing big tiles). 
+ """ + if isinstance(scalar, Value): + scalar = _unwrap(scalar) + + _pto.print_(format, scalar) + + +__all__ = [ + "get_block_idx", + "get_subblock_idx", + "get_subblock_num", + "get_block_num", + "as_tensor", + "slice_view", + "vector_section", + "cube_section", + "alloc_tile", + "load", + "store", + "print", +] diff --git a/ptodsl/api/scalar.py b/ptodsl/api/scalar.py new file mode 100644 index 00000000..7f4e9d54 --- /dev/null +++ b/ptodsl/api/scalar.py @@ -0,0 +1,168 @@ +from mlir.dialects import arith +from mlir.ir import F16Type, F32Type, IndexType, IntegerType + + +def _unwrap(value): + if isinstance(value, Value): + return value.raw + return value + + +class Value: + # TODO: generalize to more comprehensive wrappers like + # https://github.com/makslevental/mlir-python-extras/blob/0.0.8.2/mlir/extras/dialects/ext/arith.py + def __init__(self, raw): + self.raw = raw + + def __mul__(self, other): + return Value(arith.MulIOp(_unwrap(self), _unwrap(other)).result) + + def __rmul__(self, other): + return Value(arith.MulIOp(_unwrap(other), _unwrap(self)).result) + + def __add__(self, other): + return Value(arith.AddIOp(_unwrap(self), _unwrap(other)).result) + + def __radd__(self, other): + return Value(arith.AddIOp(_unwrap(other), _unwrap(self)).result) + + def __sub__(self, other): + return Value(arith.SubIOp(_unwrap(self), _unwrap(other)).result) + + def __rsub__(self, other): + return Value(arith.SubIOp(_unwrap(other), _unwrap(self)).result) + + def __floordiv__(self, other): + return Value(arith.DivSIOp(_unwrap(self), _unwrap(other)).result) + + def __rfloordiv__(self, other): + return Value(arith.DivSIOp(_unwrap(other), _unwrap(self)).result) + + def __truediv__(self, other): + return Value(arith.DivFOp(_unwrap(self), _unwrap(other)).result) + + def __rtruediv__(self, other): + return Value(arith.DivFOp(_unwrap(other), _unwrap(self)).result) + + def __mod__(self, other): + return Value(arith.RemSIOp(_unwrap(self), _unwrap(other)).result) + + def __rmod__(self, 
other): + return Value(arith.RemSIOp(_unwrap(other), _unwrap(self)).result) + + @staticmethod + def _cmp(lhs, rhs, predicate): + return Value(arith.CmpIOp(predicate, _unwrap(lhs), _unwrap(rhs)).result) + + def __lt__(self, other): + return Value._cmp(self, other, arith.CmpIPredicate.slt) + + def __gt__(self, other): + return Value._cmp(self, other, arith.CmpIPredicate.sgt) + + def __le__(self, other): + return Value._cmp(self, other, arith.CmpIPredicate.sle) + + def __ge__(self, other): + return Value._cmp(self, other, arith.CmpIPredicate.sge) + + def __eq__(self, other): + return Value._cmp(self, other, arith.CmpIPredicate.eq) + + def __ne__(self, other): + return Value._cmp(self, other, arith.CmpIPredicate.ne) + + def __getattr__(self, item): + return getattr(self.raw, item) + + +def wrap_value(value): + if isinstance(value, Value): + return value + return Value(value) + + +def __getattr__(name): + # TODO: add more builtin dtype aliases (for example float16/bfloat16/int8/int64) + # when they are validated against PTO type support. 
+ if name == "bool": + return IntegerType.get_signless(1) + if name == "float32": + return F32Type.get() + if name == "float16": + return F16Type.get() + if name == "int32": + return IntegerType.get_signless(32) + if name == "int16": + return IntegerType.get_signless(16) + if name == "uint32": + return IntegerType.get_unsigned(32) + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + +def const(value): + return Value(arith.ConstantOp(IndexType.get(), value).result) + + +def index_cast(value, index_type=IndexType): + if hasattr(index_type, "get"): + dst = index_type.get() + else: + dst = index_type + return Value(arith.IndexCastOp(dst, _unwrap(value)).result) + + +def ceil_div(a, b): + return Value(arith.CeilDivSIOp(_unwrap(a), _unwrap(b)).result) + + +def div_s(a, b): + return Value(arith.DivSIOp(_unwrap(a), _unwrap(b)).result) + + +def rem_s(a, b): + return Value(arith.RemSIOp(_unwrap(a), _unwrap(b)).result) + + +def min_u(a, b): + return Value(arith.MinUIOp(_unwrap(a), _unwrap(b)).result) + + +def eq(a, b): + return Value(arith.CmpIOp(arith.CmpIPredicate.eq, _unwrap(a), _unwrap(b)).result) + + +def lt(a, b): + return Value(arith.CmpIOp(arith.CmpIPredicate.slt, _unwrap(a), _unwrap(b)).result) + + +def gt(a, b): + return Value(arith.CmpIOp(arith.CmpIPredicate.sgt, _unwrap(a), _unwrap(b)).result) + + +def ge(a, b): + return Value(arith.CmpIOp(arith.CmpIPredicate.sge, _unwrap(a), _unwrap(b)).result) + + +def select(cond, true_val, false_val): + return Value( + arith.SelectOp(_unwrap(cond), _unwrap(true_val), _unwrap(false_val)).result + ) + + +__all__ = [ + "Value", + "_unwrap", + "wrap_value", + "const", + "index_cast", + "ceil_div", + "div_s", + "rem_s", + "min_u", + "eq", + "lt", + "gt", + "ge", + "select", +] diff --git a/ptodsl/api/synchronization.py b/ptodsl/api/synchronization.py new file mode 100644 index 00000000..1a0801ee --- /dev/null +++ b/ptodsl/api/synchronization.py @@ -0,0 +1,70 @@ +from typing import Sequence + +from 
mlir.dialects import pto as _pto + + +def _resolve_sync_op(sync_op): + if isinstance(sync_op, str): + normalized = sync_op.strip().upper() + if not normalized.startswith("T"): + normalized = f"T{normalized}" + try: + return getattr(_pto, normalized) + except AttributeError as exc: + raise ValueError(f"Unsupported sync op type '{sync_op}'.") from exc + return sync_op + + +def _resolve_event_id(event_id): + if isinstance(event_id, int): + if event_id < 0 or event_id > 7: + raise ValueError(f"event_id must be in range [0, 7], got {event_id}.") + return getattr(_pto, f"EVENT_ID{event_id}") + return event_id + + +def record_event(record_op, wait_op, event_id: int | Sequence[int] = 0): + if not isinstance(event_id, int): + for eid in event_id: + _pto.record_event( + _resolve_sync_op(record_op), + _resolve_sync_op(wait_op), + _resolve_event_id(eid), + ) + else: + _pto.record_event( + _resolve_sync_op(record_op), + _resolve_sync_op(wait_op), + _resolve_event_id(event_id), + ) + + +def wait_event(record_op, wait_op, event_id: int | Sequence[int] = 0): + if not isinstance(event_id, int): + for eid in event_id: + _pto.wait_event( + _resolve_sync_op(record_op), + _resolve_sync_op(wait_op), + _resolve_event_id(eid), + ) + else: + _pto.wait_event( + _resolve_sync_op(record_op), + _resolve_sync_op(wait_op), + _resolve_event_id(event_id), + ) + + +def record_wait_pair(record_op, wait_op, event_id: int | Sequence[int] = 0): + record = _resolve_sync_op(record_op) + wait = _resolve_sync_op(wait_op) + event = _resolve_event_id(event_id) + _pto.record_event(record, wait, event) + _pto.wait_event(record, wait, event) + + +def barrier(sync_op): + _pto.barrier(_resolve_sync_op(sync_op)) + + +__all__ = ["record_event", "wait_event", "record_wait_pair", "barrier"] diff --git a/ptodsl/api/tile.py b/ptodsl/api/tile.py new file mode 100644 index 00000000..2cffe513 --- /dev/null +++ b/ptodsl/api/tile.py @@ -0,0 +1,204 @@ +from mlir.dialects import arith as _arith +from mlir.dialects import pto 
as _pto +from mlir.ir import BoolAttr, IntegerType + +from .scalar import _unwrap + + +def mov(source, dest): + _pto.TMovOp(None, source, dest) + + +def add(lhs, rhs, out): + _pto.TAddOp(lhs, rhs, out) + + +def sub(lhs, rhs, out): + _pto.TSubOp(lhs, rhs, out) + + +def div(lhs, rhs, out): + _pto.TDivOp(lhs, rhs, out) + + +def mul(lhs, rhs, out): + _pto.TMulOp(lhs, rhs, out) + + +def or_(lhs, rhs, out): + _pto.TOrOp(lhs, rhs, out) + + +def min(lhs, rhs, out): + _pto.TMinOp(lhs, rhs, out) + + +def max(lhs, rhs, out): + _pto.TMaxOp(lhs, rhs, out) + + +def gather(src, out, indices=None, *, mask_pattern=None): + if mask_pattern is not None: + mask = _pto.MaskPatternAttr.get(getattr(_pto.MaskPattern, mask_pattern)) + _pto.TGatherOp(src, out, maskPattern=mask) + else: + _pto.TGatherOp(src, out, indices=indices) + + +def exp(inp, out): + _pto.TExpOp(inp, out) + + +def log(inp, out): + _pto.TLogOp(inp, out) + + +def relu(inp, out): + _pto.TReluOp(inp, out) + + +def abs(inp, out): + _pto.TAbsOp(inp, out) + + +def sqrt(inp, out): + _pto.TSqrtOp(inp, out) + + +def rsqrt(inp, out): + _pto.TRsqrtOp(inp, out) + + +def reciprocal(inp, out): + _pto.TRecipOp(inp, out) + + +def matmul(lhs, rhs, out): + _pto.TMatmulOp(None, lhs, rhs, out) + + +def matmul_bias(lhs, rhs, bias, out): + _pto.TMatmulBiasOp(None, lhs, rhs, bias, out) + + +def matmul_acc(acc, lhs, rhs, out): + _pto.TMatmulAccOp(None, acc, lhs, rhs, out) + + +def extract(source, index_row, index_col, out): + _pto.TExtractOp( + src=source, indexRow=_unwrap(index_row), indexCol=_unwrap(index_col), dst=out + ) + + +def row_sum(src, tmp, dst): + _pto.TRowSumOp(src=src, tmp=tmp, dst=dst) + + +def row_min(src, tmp, dst): + _pto.TRowMinOp(src=src, tmp=tmp, dst=dst) + + +def row_max(src, tmp, dst): + _pto.TRowMaxOp(src=src, tmp=tmp, dst=dst) + + +def row_prod(src, tmp, dst): + _pto.TRowProdOp(src=src, tmp=tmp, dst=dst) + + +def row_expand(src, dst): + _pto.TRowExpandOp(src=src, dst=dst) + + +def row_expand_sub(src0, src1, dst): + 
_pto.TRowExpandSubOp(src0=src0, src1=src1, dst=dst) + + +def row_expand_div(src0, src1, dst): + _pto.TRowExpandDivOp(src0=src0, src1=src1, dst=dst) + + +def row_expand_mul(src0, src1, dst): + _pto.TRowExpandMulOp(src0=src0, src1=src1, dst=dst) + + +def col_sum(src, tmp, dst, is_binary=True): + _pto.TColSumOp(src=src, dst=dst, tmp=tmp, isBinary=BoolAttr.get(is_binary)) + + +def col_min(src, dst): + _pto.TColMinOp(src=src, dst=dst) + + +def col_max(src, dst): + _pto.TColMaxOp(src=src, dst=dst) + + +def col_prod(src, tmp, dst, is_binary=True): + _pto.TColProdOp(src=src, dst=dst, tmp=tmp, isBinary=BoolAttr.get(is_binary)) + + +def col_expand(src, dst): + _pto.TColExpandOp(src=src, dst=dst) + + +def mrgsort(src, dst, block_len): + i32 = IntegerType.get_signless(32) + block_len_i32 = _arith.IndexCastOp(i32, _unwrap(block_len)).result + _pto.TMrgSortOp(srcs=[src], dsts=[dst], blockLen=block_len_i32) + + +def sort32(src, dst, idx): + """TSORT32: sort src tile within 32-element blocks, writing interleaved + (score, index) pairs to dst. idx is an input tile of uint32 indices + attached to each src element. 
For float16 src, dst must have 4x the + columns of src (each element expands to 4 float16 words).""" + _pto.TSort32Op(src, dst, idx) + + +def subset(source, offsets, sizes): + offset_vals = [_unwrap(v) for v in offsets] + return _pto.subset(source, offset_vals, sizes) + + +def print(source): + _pto.tprint(source) + + +__all__ = [ + "mov", + "add", + "sub", + "div", + "mul", + "or_", + "gather", + "exp", + "log", + "relu", + "abs", + "sqrt", + "rsqrt", + "reciprocal", + "matmul", + "matmul_bias", + "matmul_acc", + "extract", + "row_sum", + "row_min", + "row_max", + "row_prod", + "row_expand", + "row_expand_sub", + "row_expand_div", + "row_expand_mul", + "col_sum", + "col_min", + "col_max", + "col_prod", + "col_expand", + "mrgsort", + "sort32", + "subset", +] diff --git a/ptodsl/api/type_def.py b/ptodsl/api/type_def.py new file mode 100644 index 00000000..251303f6 --- /dev/null +++ b/ptodsl/api/type_def.py @@ -0,0 +1,112 @@ +from mlir.dialects import pto as _pto + +from . import scalar + + +def __getattr__(name): + # MLIR type factories require an active context, so keep dtype aliases lazy + # and resolve them only when user code accesses them inside PTO/MLIR setup. + if name in {"bool", "float16", "float32", "int16", "int32", "uint32"}: + return getattr(scalar, name) + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + + +def PtrType(dtype): + return _pto.PtrType.get(dtype) + + +def TensorType(*, rank, dtype): + return _pto.TensorViewType.get(rank, dtype) + + +def SubTensorType(*, shape, dtype): + return _pto.PartitionTensorViewType.get(shape, dtype) + + +class TileBufConfig: + def __init__( + self, blayout="RowMajor", slayout="NoneBox", s_fractal_size=512, pad="Null" + ): + # TODO: expose and validate a broader set of tile buffer knobs if PTO adds + # more layout/padding/fractal settings that should be configurable here. 
+ self._bl = _pto.BLayoutAttr.get(getattr(_pto.BLayout, blayout)) + self._sl = _pto.SLayoutAttr.get(getattr(_pto.SLayout, slayout)) + self._pd = _pto.PadValueAttr.get(getattr(_pto.PadValue, pad)) + self._s_fractal_size = s_fractal_size + + @property + def attr(self): + return _pto.TileBufConfigAttr.get( + self._bl, self._sl, self._s_fractal_size, self._pd + ) + + +def _default_tile_config(memory_space, shape): + space = memory_space.upper() + # Defaults mirror the explicit configs used by the verbose matmul builder. + if space == "MAT": + if len(shape) >= 1 and shape[0] == 1: + return TileBufConfig( + blayout="RowMajor", + slayout="NoneBox", + s_fractal_size=_pto.TileConfig.fractalABSize, + ) + return TileBufConfig( + blayout="ColMajor", + slayout="RowMajor", + s_fractal_size=_pto.TileConfig.fractalABSize, + ) + if space == "LEFT": + return TileBufConfig( + blayout="RowMajor", + slayout="RowMajor", + s_fractal_size=_pto.TileConfig.fractalABSize, + ) + if space == "RIGHT": + return TileBufConfig( + blayout="RowMajor", + slayout="ColMajor", + s_fractal_size=_pto.TileConfig.fractalABSize, + ) + if space == "ACC": + return TileBufConfig( + blayout="ColMajor", + slayout="RowMajor", + s_fractal_size=_pto.TileConfig.fractalCSize, + ) + if space == "BIAS": + return TileBufConfig( + blayout="RowMajor", + slayout="NoneBox", + s_fractal_size=_pto.TileConfig.fractalABSize, + ) + if space == "VEC": + return TileBufConfig() + raise ValueError( + f"Unsupported memory_space '{memory_space}' for default tile config." 
+ ) + + +def TileBufType(*, shape, dtype, memory_space, valid_shape=None, config=None): + space = _pto.AddressSpaceAttr.get(getattr(_pto.AddressSpace, memory_space)) + if valid_shape is None: + valid_shape = shape + if config is None: + config = _default_tile_config(memory_space, shape) + cfg = config.attr if isinstance(config, TileBufConfig) else config + return _pto.TileBufType.get(shape, dtype, space, valid_shape, cfg) + + +__all__ = [ + "PtrType", + "TensorType", + "SubTensorType", + "TileBufConfig", + "TileBufType", + "bool", + "float16", + "float32", + "int16", + "int32", + "uint32", +] diff --git a/ptodsl/bench.py b/ptodsl/bench.py index b9efa909..b6ae4870 100644 --- a/ptodsl/bench.py +++ b/ptodsl/bench.py @@ -1,52 +1,3 @@ -from typing import Callable, List, Literal, Union +from .utils.bench import do_bench - -def do_bench( - fn: Callable, - warmup_iters: int = 5, - benchmark_iters: int = 15, - aggregation: Literal["mean", "none"] = "mean", - unit: Literal["s", "ms", "us", "ns"] = "us", - flush_cache: bool = True, -) -> Union[float, List[float]]: - """ - Benchmark a given function with warmup. - - Args: - fn: Function to benchmark. - warmup_iters: Number of warmup runs. - benchmark_iters: Number of benchmark runs. - aggregation: Aggregation mode for benchmark times. - unit: Time unit of the benchmarks. - flush_cache: if we should overwrite l2 cache between every iteration - Returns: - Runtime, or list of runtimes, in specified units. 
- """ - import torch - import torch_npu - start_events = [torch.npu.Event(enable_timing=True) for _ in range(benchmark_iters)] - end_events = [torch.npu.Event(enable_timing=True) for _ in range(benchmark_iters)] - - # Allocate a 256 MB tensor which we write to every iteration to flush L2 cache - # https://github.com/tile-ai/tilelang/blob/main/tilelang/profiler/bench.py#L103 - cache_size = 256 * 1024 * 1024 - cache = torch.empty((cache_size), dtype=torch.int8).npu() - - for _ in range(warmup_iters): - fn() - torch_npu.npu.synchronize() - - for i in range(benchmark_iters): - if flush_cache: - cache.zero_() - torch_npu.npu.synchronize() - start_events[i].record() - fn() - end_events[i].record() - torch_npu.npu.synchronize() - - f = {"s": 1e-3, "ms": 1e0, "us": 1e3, "ns": 1e6}[unit] - times = [f * s.elapsed_time(e) for s, e in zip(start_events, end_events)] - if aggregation == "mean": - return sum(times) / len(times) - return times +__all__ = ["do_bench"] diff --git a/ptodsl/compiler/__init__.py b/ptodsl/compiler/__init__.py new file mode 100644 index 00000000..1e6d9312 --- /dev/null +++ b/ptodsl/compiler/__init__.py @@ -0,0 +1,4 @@ +from .ir import to_ir_module +from .jit import JitWrapper, jit + +__all__ = ["JitWrapper", "jit", "to_ir_module"] diff --git a/ptodsl/compiler/ir.py b/ptodsl/compiler/ir.py new file mode 100644 index 00000000..488d8638 --- /dev/null +++ b/ptodsl/compiler/ir.py @@ -0,0 +1,122 @@ +import inspect + +from mlir.dialects import func, pto as _pto +from mlir.ir import Context, InsertionPoint, Location, Module + +from ..api.scalar import wrap_value +from ..constexpr import is_constexpr_annotation + + +def _resolve_meta(meta_fn): + values = meta_fn() + if not isinstance(values, dict): + raise ValueError( + "`meta_data()` must return a dict of named symbols to MLIR/PTO types." 
+ ) + return dict(values) + + +def _resolve_arg_types(signature, meta_map): + arg_types = [] + for param in signature.parameters.values(): + annot = param.annotation + if is_constexpr_annotation(annot): + continue + if isinstance(annot, str): + if annot not in meta_map: + raise ValueError(f"Unknown annotation '{annot}'.") + arg_types.append(meta_map[annot]) + elif annot is inspect._empty: + raise ValueError(f"Missing annotation for argument '{param.name}'.") + else: + arg_types.append(annot) + return arg_types + + +def _resolve_ret_types(signature, meta_map): + ret_annot = signature.return_annotation + if ret_annot in (inspect._empty, None): + return [] + if isinstance(ret_annot, str): + if ret_annot not in meta_map: + raise ValueError(f"Unknown return annotation '{ret_annot}'.") + return [meta_map[ret_annot]] + if isinstance(ret_annot, (list, tuple)): + out = [] + for elem in ret_annot: + if isinstance(elem, str): + out.append(meta_map[elem]) + else: + out.append(elem) + return out + return [ret_annot] + + +def _has_func_return(block): + last_name = None + for op in block.operations: + last_name = op.operation.name + return last_name == "func.return" + + +def _inject_globals(fn, values): + old = {} + for name, value in values.items(): + old[name] = fn.__globals__.get(name, None) + fn.__globals__[name] = value + return old + + +def _restore_globals(fn, old, injected_names): + for name in injected_names: + if old[name] is None and name in fn.__globals__: + del fn.__globals__[name] + else: + fn.__globals__[name] = old[name] + + +def to_ir_module(*, meta_data): + def decorator(fn): + sig = inspect.signature(fn) + + with Context() as ctx, Location.unknown(): + _pto.register_dialect(ctx, load=True) + meta_map = _resolve_meta(meta_data) + arg_types = _resolve_arg_types(sig, meta_map) + ret_types = _resolve_ret_types(sig, meta_map) + module = Module.create() + fn_ty = func.FunctionType.get(arg_types, ret_types) + + with InsertionPoint(module.body): + ir_func = 
func.FuncOp(fn.__name__, fn_ty) + entry = ir_func.add_entry_block() + + with InsertionPoint(entry): + wrapped_args = [] + entry_arg_iter = iter(entry.arguments) + for param in sig.parameters.values(): + if is_constexpr_annotation(param.annotation): + if param.default is inspect._empty: + raise ValueError( + f"Constexpr argument '{param.name}' requires a default value." + ) + wrapped_args.append(param.default) + else: + wrapped_args.append(wrap_value(next(entry_arg_iter))) + injected = set(meta_map.keys()) + old_globals = _inject_globals(fn, meta_map) + try: + fn(*wrapped_args) + finally: + _restore_globals(fn, old_globals, injected) + + if not ret_types and not _has_func_return(entry): + func.ReturnOp([]) + + module.operation.verify() + return module + + return decorator + + +__all__ = ["to_ir_module"] diff --git a/ptodsl/compiler/jit.py b/ptodsl/compiler/jit.py new file mode 100644 index 00000000..9e39ebeb --- /dev/null +++ b/ptodsl/compiler/jit.py @@ -0,0 +1,312 @@ +import ctypes +import inspect +import os +import pathlib +import subprocess +from functools import update_wrapper + +from mlir.dialects import pto as _pto +from mlir.ir import Context, Location + +from .ir import to_ir_module + + +def _type_repr(type_obj): + return str(type_obj).replace(" ", "").lower() + + +def _is_ptr_type(type_obj): + return "ptr" in _type_repr(type_obj) + + +def _ptr_elem_cpp_type(type_obj): + type_repr = _type_repr(type_obj) + if "e8m0" in type_repr: + return "float8_e8m0_t" + if "e4m3" in type_repr: + return "float8_e4m3_t" + if "e5m2" in type_repr: + return "float8_e5m2_t" + if "f32" in type_repr: + return "float" + if "f16" in type_repr: + return "__fp16" + if "bf16" in type_repr: + return "__bf16" + if "i8" in type_repr: + return "int8_t" + if "u8" in type_repr: + return "uint8_t" + if "i16" in type_repr: + return "int16_t" + if "u16" in type_repr: + return "uint16_t" + if "i32" in type_repr: + return "int32_t" + if "u32" in type_repr: + return "uint32_t" + if "i64" in 
type_repr: + return "int64_t" + if "u64" in type_repr: + return "uint64_t" + return "float" + + +def _scalar_cpp_type(type_obj): + type_repr = _type_repr(type_obj) + if "i32" in type_repr: + return "int32_t" + if "i64" in type_repr or "index" in type_repr: + return "int64_t" + if "e8m0" in type_repr or "e4m3" in type_repr or "e5m2" in type_repr: + return "uint8_t" + if "f32" in type_repr: + return "float" + if "f16" in type_repr: + return "__fp16" + return "int32_t" + + +def _scalar_ctype(type_obj): + type_repr = _type_repr(type_obj) + if "i64" in type_repr or "index" in type_repr: + return ctypes.c_int64 + if "e8m0" in type_repr or "e4m3" in type_repr or "e5m2" in type_repr: + return ctypes.c_uint8 + if "f32" in type_repr: + return ctypes.c_float + if "f16" in type_repr: + return ctypes.c_uint16 + return ctypes.c_int32 + + +def _normalize_stream_ptr(stream_ptr): + if isinstance(stream_ptr, ctypes.c_void_p): + return stream_ptr + if isinstance(stream_ptr, int): + return ctypes.c_void_p(stream_ptr) + if hasattr(stream_ptr, "value"): + return ctypes.c_void_p(int(stream_ptr.value)) + return stream_ptr + + +class JitWrapper: + def __init__( + self, + fn, + *, + meta_data, + output_dir=None, + block_dim=20, + enable_insert_sync=True, + npu_arch="dav-2201", + ): + self._fn = fn + self._meta_data = meta_data + self._sig = inspect.signature(fn) + self._arg_types = None + self._output_dir = ( + pathlib.Path(output_dir) + if output_dir + else pathlib.Path.cwd() / ".ptodsl_jit" / fn.__name__ + ) + self._block_dim = block_dim + self._enable_insert_sync = enable_insert_sync + self._npu_arch = npu_arch + self._compiled = False + self._lib = None + self._lib_path = self._output_dir / "kernel.so" + update_wrapper(self, fn) + + def _artifact_paths(self): + pto_path = self._output_dir / "kernel.pto" + cpp_path = self._output_dir / "kernel.cpp" + caller_path = self._output_dir / "caller.cpp" + return pto_path, cpp_path, caller_path, self._lib_path + + def _generate_caller_cpp(self, 
kernel_cpp_name): + params = list(self._sig.parameters.values()) + cpp_args = [] + launch_args = [] + for param, arg_type in zip(params, self._arg_types): + if _is_ptr_type(arg_type): + cpp_args.append(f"uint8_t *{param.name}") + launch_args.append(f"({_ptr_elem_cpp_type(arg_type)} *){param.name}") + else: + cpp_t = _scalar_cpp_type(arg_type) + cpp_args.append(f"{cpp_t} {param.name}") + launch_args.append(param.name) + + wrapper_sig = ", ".join(["uint32_t blockDim", "void *stream"] + cpp_args) + kernel_call = ", ".join(launch_args) + return ( + f'#include "{kernel_cpp_name}"\n' + f"#include \n\n" + f'extern "C" void call_kernel({wrapper_sig})\n' + "{\n" + f" {self._fn.__name__}<<>>({kernel_call});\n" + "}\n" + ) + + def _compile_shared_library(self, caller_cpp_path, lib_path): + toolkit_home = os.environ.get("ASCEND_TOOLKIT_HOME") + if not toolkit_home: + raise RuntimeError( + "ASCEND_TOOLKIT_HOME is required to compile generated caller.cpp." + ) + cmd = [ + "bisheng", + f"-I{toolkit_home}/include", + "-fPIC", + "-shared", + "-D_FORTIFY_SOURCE=2", + "-O2", + "-std=c++17", + "-Wno-macro-redefined", + "-Wno-ignored-attributes", + "-fstack-protector-strong", + "-xcce", + "-Xhost-start", + "-Xhost-end", + "-mllvm", + "-cce-aicore-stack-size=0x8000", + "-mllvm", + "-cce-aicore-function-stack-size=0x8000", + "-mllvm", + "-cce-aicore-record-overflow=true", + "-mllvm", + "-cce-aicore-addr-transform", + "-mllvm", + "-cce-aicore-dcci-insert-for-scalar=false", + f"--npu-arch={self._npu_arch}", + "-DMEMORY_BASE", # TODO: add switch for A5 + "-std=gnu++17", + str(caller_cpp_path), + "-o", + str(lib_path), + ] + subprocess.run(cmd, check=True, cwd=str(self._output_dir)) + + def _resolve_runtime_arg_types(self): + from .ir import _resolve_arg_types, _resolve_meta + + with Context() as ctx, Location.unknown(): + _pto.register_dialect(ctx, load=True) + meta_map = _resolve_meta(self._meta_data) + return _resolve_arg_types(self._sig, meta_map) + + def _build(self): + 
self._output_dir.mkdir(parents=True, exist_ok=True) + pto_path, cpp_path, caller_path, lib_path = self._artifact_paths() + self._arg_types = self._resolve_runtime_arg_types() + + ir_module = to_ir_module(meta_data=self._meta_data)(self._fn) + pto_path.write_text(f"{ir_module}\n", encoding="utf-8") + + ptoas_cmd = ["ptoas"] + if self._enable_insert_sync: + ptoas_cmd.append("--enable-insert-sync") + ptoas_cmd += [str(pto_path), "-o", str(cpp_path)] + subprocess.run(ptoas_cmd, check=True, cwd=str(self._output_dir)) + + caller_path.write_text( + self._generate_caller_cpp(cpp_path.name), encoding="utf-8" + ) + self._compile_shared_library(caller_path, lib_path) + + self._lib = ctypes.CDLL(str(lib_path)) + self._lib.call_kernel.argtypes = [ctypes.c_uint32, ctypes.c_void_p] + [ + ctypes.c_void_p if _is_ptr_type(arg_type) else _scalar_ctype(arg_type) + for arg_type in self._arg_types + ] + self._compiled = True + + def _convert_ptr(self, value): + if isinstance(value, ctypes.c_void_p): + return value + if hasattr(value, "data_ptr"): + return ctypes.c_void_p(value.data_ptr()) + if isinstance(value, int): + return ctypes.c_void_p(value) + raise TypeError(f"Pointer-like argument expected, got {type(value)!r}.") + + def _prepare_call_args(self, args): + params = list(self._sig.parameters.values()) + if len(args) > len(params): + raise TypeError( + f"Expected at most {len(params)} arguments, got {len(args)}." 
+ ) + + filled_args = list(args) + for idx in range(len(args), len(params)): + param = params[idx] + if param.default is not inspect._empty: + filled_args.append(param.default) + continue + arg_type = self._arg_types[idx] + if _is_ptr_type(arg_type): + raise TypeError(f"Missing required pointer argument '{param.name}'.") + + converted = [] + for value, arg_type in zip(filled_args, self._arg_types): + if _is_ptr_type(arg_type): + converted.append(self._convert_ptr(value)) + else: + converted.append(value) + return converted + + # TODO: also allow taking named `kwargs` + def __call__(self, *args, stream_ptr=None): + if not self._compiled: + self._build() + + if stream_ptr is None: + import torch + + stream_ptr = torch.npu.current_stream()._as_parameter_ + + call_args = self._prepare_call_args(args) + self._lib.call_kernel( + ctypes.c_uint32(self._block_dim), + _normalize_stream_ptr(stream_ptr), + *call_args, + ) + return None + + def set_block_dim(self, block_dim): + if not isinstance(block_dim, int) or block_dim <= 0: + raise ValueError("`block_dim` must be a positive integer.") + self._block_dim = block_dim + return self + + @property + def library_path(self): + return str(self._lib_path) + + @property + def output_dir(self): + return str(self._output_dir) + + +def jit( + *, + meta_data, + output_dir=None, + block_dim=1, + enable_insert_sync=True, + npu_arch="dav-2201", +): + def decorator(fn): + return JitWrapper( + fn, + meta_data=meta_data, + output_dir=output_dir, + block_dim=block_dim, + enable_insert_sync=enable_insert_sync, + npu_arch=npu_arch, + ) + + return decorator + + +__all__ = ["JitWrapper", "jit"] diff --git a/ptodsl/constexpr.py b/ptodsl/constexpr.py new file mode 100644 index 00000000..a1475b9a --- /dev/null +++ b/ptodsl/constexpr.py @@ -0,0 +1,37 @@ +import builtins + + +class ConstexprAnnotation: + __ptodsl_constexpr__ = True + + def __init__(self, inner_type): + self.inner_type = inner_type + + def __repr__(self): + return 
f"Constexpr[{self.inner_type!r}]" + + +class Constexpr: + def __class_getitem__(cls, inner_type): + return ConstexprAnnotation(inner_type) + + +def is_constexpr_annotation(annotation): + return getattr(annotation, "__ptodsl_constexpr__", False) + + +def const_expr(value): + return value + + +def range_constexpr(*args): + return builtins.range(*args) + + +__all__ = [ + "Constexpr", + "ConstexprAnnotation", + "const_expr", + "is_constexpr_annotation", + "range_constexpr", +] diff --git a/ptodsl/language.py b/ptodsl/language.py index d465ddd4..de97b725 100644 --- a/ptodsl/language.py +++ b/ptodsl/language.py @@ -1,7 +1,10 @@ from contextlib import contextmanager +from dataclasses import dataclass +from typing import Sequence +from mlir import ir as mlir_ir from mlir.dialects import arith, pto, scf -from mlir.ir import F16Type, F32Type, IndexType, InsertionPoint, IntegerType +from mlir.ir import IndexType, InsertionPoint, IntegerType def _unwrap(value): @@ -83,15 +86,82 @@ def wrap_value(value): return Value(value) +@dataclass(frozen=True) +class MXFP8DType: + lhs: object + rhs: object + scale: object + acc: object + scale_factor: int = 32 + + @property + def data(self): + return self.lhs + + def scale_k(self, k): + if k % self.scale_factor != 0: + raise ValueError( + f"k={k} must be divisible by scale_factor={self.scale_factor} for MXFP8." 
+ ) + return k // self.scale_factor + + +def _get_mlir_float_type(alias_name, *type_names): + for type_name in type_names: + type_ctor = getattr(mlir_ir, type_name, None) + if type_ctor is not None: + return type_ctor.get() + supported = ", ".join(type_names) + raise AttributeError( + f"module '{__name__}' has no attribute '{alias_name}' because the active MLIR " + f"Python bindings do not expose any of: {supported}" + ) + + +def make_mxfp8(*, lhs="e5m2", rhs="e5m2", acc=None, scale_factor=32): + variants = { + "e4m3": __getattr__("fp8_e4m3"), + "e5m2": __getattr__("fp8_e5m2"), + } + if lhs not in variants: + raise ValueError( + f"Unsupported lhs variant '{lhs}'. Expected one of: {', '.join(sorted(variants))}." + ) + if rhs not in variants: + raise ValueError( + f"Unsupported rhs variant '{rhs}'. Expected one of: {', '.join(sorted(variants))}." + ) + return MXFP8DType( + lhs=variants[lhs], + rhs=variants[rhs], + scale=__getattr__("fp8_e8m0"), + acc=__getattr__("float32") if acc is None else acc, + scale_factor=scale_factor, + ) + + def __getattr__(name): - # TODO: add more builtin dtype aliases (for example float16/bfloat16/int8/int64) - # when they are validated against PTO type support. + # Keep aliases conservative and only expose types that map cleanly to MLIR/PTO. 
if name == "bool": return IntegerType.get_signless(1) if name == "float32": - return F32Type.get() + return _get_mlir_float_type(name, "F32Type", "Float32Type") if name == "float16": - return F16Type.get() + return _get_mlir_float_type(name, "F16Type", "Float16Type") + if name == "bfloat16": + return _get_mlir_float_type(name, "BF16Type") + if name in ("fp8_e4m3", "float8_e4m3"): + return _get_mlir_float_type(name, "Float8E4M3FNType", "Float8E4M3FNUZType") + if name in ("fp8_e5m2", "float8_e5m2"): + return _get_mlir_float_type(name, "Float8E5M2Type", "Float8E5M2FNUZType") + if name in ("fp8_e8m0", "float8_e8m0"): + return _get_mlir_float_type(name, "Float8E8M0FNUType", "Float8E8M0FNType") + if name == "mxfp8": + return make_mxfp8(lhs="e5m2", rhs="e5m2") + if name == "mxfp8_e4m3": + return make_mxfp8(lhs="e4m3", rhs="e4m3") + if name == "mxfp8_e5m2": + return make_mxfp8(lhs="e5m2", rhs="e5m2") if name == "int32": return IntegerType.get_signless(32) if name == "int16": @@ -112,7 +182,9 @@ def SubTensorType(*, shape, dtype): class TileBufConfig: - def __init__(self, blayout="RowMajor", slayout="NoneBox", s_fractal_size=512, pad="Null"): + def __init__( + self, blayout="RowMajor", slayout="NoneBox", s_fractal_size=512, pad="Null" + ): # TODO: expose and validate a broader set of tile buffer knobs if PTO adds # more layout/padding/fractal settings that should be configurable here. self._bl = pto.BLayoutAttr.get(getattr(pto.BLayout, blayout)) @@ -122,7 +194,9 @@ def __init__(self, blayout="RowMajor", slayout="NoneBox", s_fractal_size=512, pa @property def attr(self): - return pto.TileBufConfigAttr.get(self._bl, self._sl, self._s_fractal_size, self._pd) + return pto.TileBufConfigAttr.get( + self._bl, self._sl, self._s_fractal_size, self._pd + ) def _default_tile_config(memory_space, shape): @@ -130,19 +204,51 @@ def _default_tile_config(memory_space, shape): # Defaults mirror the explicit configs used by the verbose matmul builder. 
if space == "MAT": if len(shape) >= 1 and shape[0] == 1: - return TileBufConfig(blayout="RowMajor", slayout="NoneBox", s_fractal_size=pto.TileConfig.fractalABSize) - return TileBufConfig(blayout="ColMajor", slayout="RowMajor", s_fractal_size=pto.TileConfig.fractalABSize) + return TileBufConfig( + blayout="RowMajor", + slayout="NoneBox", + s_fractal_size=pto.TileConfig.fractalABSize, + ) + return TileBufConfig( + blayout="ColMajor", + slayout="RowMajor", + s_fractal_size=pto.TileConfig.fractalABSize, + ) if space == "LEFT": - return TileBufConfig(blayout="RowMajor", slayout="RowMajor", s_fractal_size=pto.TileConfig.fractalABSize) + return TileBufConfig( + blayout="RowMajor", + slayout="RowMajor", + s_fractal_size=pto.TileConfig.fractalABSize, + ) if space == "RIGHT": - return TileBufConfig(blayout="RowMajor", slayout="ColMajor", s_fractal_size=pto.TileConfig.fractalABSize) + return TileBufConfig( + blayout="RowMajor", + slayout="ColMajor", + s_fractal_size=pto.TileConfig.fractalABSize, + ) if space == "ACC": - return TileBufConfig(blayout="ColMajor", slayout="RowMajor", s_fractal_size=pto.TileConfig.fractalCSize) + return TileBufConfig( + blayout="ColMajor", + slayout="RowMajor", + s_fractal_size=pto.TileConfig.fractalCSize, + ) if space == "BIAS": - return TileBufConfig(blayout="RowMajor", slayout="NoneBox", s_fractal_size=pto.TileConfig.fractalABSize) + return TileBufConfig( + blayout="RowMajor", + slayout="NoneBox", + s_fractal_size=pto.TileConfig.fractalABSize, + ) + if space == "SCALING": + return TileBufConfig( + blayout="RowMajor", + slayout="NoneBox", + s_fractal_size=pto.TileConfig.fractalABSize, + ) if space == "VEC": return TileBufConfig() - raise ValueError(f"Unsupported memory_space '{memory_space}' for default tile config.") + raise ValueError( + f"Unsupported memory_space '{memory_space}' for default tile config." 
+ ) def TileBufType(*, shape, dtype, memory_space, valid_shape=None, config=None): @@ -155,6 +261,38 @@ def TileBufType(*, shape, dtype, memory_space, valid_shape=None, config=None): return pto.TileBufType.get(shape, dtype, space, valid_shape, cfg) +def LeftScaleTileBufType(*, shape, dtype, valid_shape=None, config=None): + if config is None: + config = TileBufConfig( + blayout="RowMajor", + slayout="RowMajor", + s_fractal_size=pto.TileConfig.fractalMxSize, + ) + return TileBufType( + shape=shape, + dtype=dtype, + memory_space="SCALING", + valid_shape=valid_shape, + config=config, + ) + + +def RightScaleTileBufType(*, shape, dtype, valid_shape=None, config=None): + if config is None: + config = TileBufConfig( + blayout="ColMajor", + slayout="ColMajor", + s_fractal_size=pto.TileConfig.fractalMxSize, + ) + return TileBufType( + shape=shape, + dtype=dtype, + memory_space="SCALING", + valid_shape=valid_shape, + config=config, + ) + + def const(value): return Value(arith.ConstantOp(IndexType.get(), value).result) @@ -186,13 +324,17 @@ def index_cast(value, index_type=IndexType): def as_tensor(tensor_type, *, ptr, shape, strides): shape_vals = [_unwrap(v) for v in shape] stride_vals = [_unwrap(v) for v in strides] - return pto.MakeTensorViewOp(tensor_type, _unwrap(ptr), shape_vals, stride_vals).result + return pto.MakeTensorViewOp( + tensor_type, _unwrap(ptr), shape_vals, stride_vals + ).result def slice_view(subtensor_type, *, source, offsets, sizes): offset_vals = [_unwrap(v) for v in offsets] size_vals = [_unwrap(v) for v in sizes] - return pto.PartitionViewOp(subtensor_type, source, offsets=offset_vals, sizes=size_vals).result + return pto.PartitionViewOp( + subtensor_type, source, offsets=offset_vals, sizes=size_vals + ).result @contextmanager @@ -227,6 +369,11 @@ def alloc_tile(tile_type, *, valid_row=None, valid_col=None): return pto.AllocTileOp(tile_type, **kwargs).result +def subset(source, offsets, sizes): + offset_vals = [_unwrap(v) for v in offsets] + return 
pto.subset(source, offset_vals, sizes) + + def load(source, dest): pto.TLoadOp(None, source, dest) @@ -299,6 +446,30 @@ def matmul_acc(acc, lhs, rhs, out): pto.TMatmulAccOp(None, acc, lhs, rhs, out) +def _emit_dps_op(op_name, *operands): + op_ctor = getattr(pto, op_name, None) + if op_ctor is not None: + return op_ctor(None, *operands) + generic_name = { + "TMatmulMxOp": "pto.tmatmul.mx", + "TMatmulMxAccOp": "pto.tmatmul.mx.acc", + "TMatmulMxBiasOp": "pto.tmatmul.mx.bias", + }[op_name] + return mlir_ir.Operation.create(generic_name, operands=list(operands)) + + +def matmul_mx(lhs, lhs_scale, rhs, rhs_scale, out): + _emit_dps_op("TMatmulMxOp", lhs, lhs_scale, rhs, rhs_scale, out) + + +def matmul_mx_acc(acc, lhs, lhs_scale, rhs, rhs_scale, out): + _emit_dps_op("TMatmulMxAccOp", acc, lhs, lhs_scale, rhs, rhs_scale, out) + + +def matmul_mx_bias(lhs, lhs_scale, rhs, rhs_scale, bias, out): + _emit_dps_op("TMatmulMxBiasOp", lhs, lhs_scale, rhs, rhs_scale, bias, out) + + def ceil_div(a, b): return Value(arith.CeilDivSIOp(_unwrap(a), _unwrap(b)).result) @@ -332,18 +503,22 @@ def ge(a, b): def select(cond, true_val, false_val): - return Value(arith.SelectOp(_unwrap(cond), _unwrap(true_val), _unwrap(false_val)).result) + return Value( + arith.SelectOp(_unwrap(cond), _unwrap(true_val), _unwrap(false_val)).result + ) class _IfElseBranch: def __init__(self, if_op): self._if_op = if_op + @contextmanager def else_context(self): with InsertionPoint(self._if_op.else_block): yield scf.YieldOp([]) + @contextmanager def if_context(condition, has_else=False): if has_else: @@ -368,6 +543,7 @@ def cond(condition, then_builder, else_builder): scf.YieldOp([]) return op + def _resolve_sync_op(sync_op): if isinstance(sync_op, str): normalized = sync_op.strip().upper() @@ -388,17 +564,49 @@ def _resolve_event_id(event_id): return event_id -def record_event(record_op, wait_op, event_id=0): - pto.record_event(_resolve_sync_op(record_op), _resolve_sync_op(wait_op), _resolve_event_id(event_id)) - 
- -def wait_event(record_op, wait_op, event_id=0): - pto.wait_event(_resolve_sync_op(record_op), _resolve_sync_op(wait_op), _resolve_event_id(event_id)) +def record_event(record_op, wait_op, event_id: int | Sequence[int] = 0): + if not isinstance(event_id, int): + for eid in event_id: + pto.record_event( + _resolve_sync_op(record_op), + _resolve_sync_op(wait_op), + _resolve_event_id(eid), + ) + else: + pto.record_event( + _resolve_sync_op(record_op), + _resolve_sync_op(wait_op), + _resolve_event_id(event_id), + ) + + +def wait_event(record_op, wait_op, event_id: int | Sequence[int] = 0): + if not isinstance(event_id, int): + for eid in event_id: + pto.wait_event( + _resolve_sync_op(record_op), + _resolve_sync_op(wait_op), + _resolve_event_id(eid), + ) + else: + pto.wait_event( + _resolve_sync_op(record_op), + _resolve_sync_op(wait_op), + _resolve_event_id(event_id), + ) -def record_wait_pair(record_op, wait_op, event_id=0): +def record_wait_pair(record_op, wait_op, event_id: int | Sequence[int] = 0): rec = _resolve_sync_op(record_op) w = _resolve_sync_op(wait_op) ev = _resolve_event_id(event_id) pto.record_event(rec, w, ev) pto.wait_event(rec, w, ev) + + +def barrier(sync_op): + pto.barrier(_resolve_sync_op(sync_op)) + + +def row_sum(src, tmp, dst): + pto.TRowSumOp(src=src, tmp=tmp, dst=dst) diff --git a/ptodsl/lib/__init__.py b/ptodsl/lib/__init__.py new file mode 100644 index 00000000..c8d47101 --- /dev/null +++ b/ptodsl/lib/__init__.py @@ -0,0 +1,3 @@ +from . import a5 + +__all__ = ["a5"] diff --git a/ptodsl/lib/a5/README.md b/ptodsl/lib/a5/README.md new file mode 100644 index 00000000..baa2c9ee --- /dev/null +++ b/ptodsl/lib/a5/README.md @@ -0,0 +1,27 @@ +# A5 Library Layer + +This directory contains a first PTODSL library-style translation layer for the +`pto-isa/include/pto/npu/a5` surface. 
+ +The scope of this pass is: + +- Pythonic wrappers over PTO tile ops and selected micro instructions +- A5-flavored compatibility aliases such as `TLoad`, `TAdd`, `TMatmul`, and `TStore` +- Translated builder kernels that emit `.pto` through PTODSL +- A checked-in generation flow for reproducible `.pto` artifacts + +Entry points: + +- [`ops.py`](./ops.py): reusable A5-style helpers built on PTODSL and PTO dialect ops +- [`kernels.py`](./kernels.py): translated example kernels +- [`generated`](./generated): emitted `.pto` artifacts from `scripts/generate_a5_pto.py` + +Regenerate the current artifacts with: + +```bash +PYTHONPATH=/Users/zhoubot/github/.llvm-19.1.7/build-mlir-py312/tools/mlir/python_packages/mlir_core:/Users/zhoubot/github/pto-org/PTOAS/install-src312:/Users/zhoubot/github/pto-org/PTOAS/build-src312/python \ +/Users/zhoubot/github/.venv-ptoas-src312/bin/python scripts/generate_a5_pto.py +``` + +`--emit-cpp` is best-effort: the tile-based kernels lower through local `ptoas`, +while the direct micro-only kernel currently remains `.pto`-only in this environment. diff --git a/ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md b/ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md new file mode 100644 index 00000000..42f4175f --- /dev/null +++ b/ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md @@ -0,0 +1,43 @@ +# Tile Micro Coverage + +- Total public tile ops: `32` +- Implemented: `26` +- Partial: `1` +- Pending: `0` +- Blocked: `4` +- Not applicable: `1` + +| tile op | status | helper | note | +| --- | --- | --- | --- | +| `mov` | `implemented` | `mov_micro` | UB stage + vlds/vsts copy loop. | +| `add` | `implemented` | `add_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vadd/vsts lowering. | +| `sub` | `implemented` | `sub_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vsub/vsts lowering. | +| `div` | `implemented` | `div_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vdiv/vsts lowering. 
| +| `mul` | `implemented` | `mul_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vmul/vsts lowering. | +| `or_` | `implemented` | `or_micro` | UB stage + constexpr-specialized TBinOp-style vlds/vor/vsts lowering. | +| `gather` | `partial` | `gather_micro` | Indexed gather is implemented via vgather2 for same-width source/index pairs; mask-pattern gather still needs unsupported vsqz-style micro support. | +| `exp` | `implemented` | `exp_micro` | UB stage + vlds/vexp/vsts loop. | +| `log` | `implemented` | `log_micro` | UB stage + vlds/vln/vsts loop. | +| `relu` | `implemented` | `relu_micro` | UB stage + vlds/vrelu/vsts loop. | +| `abs` | `implemented` | `abs_micro` | UB stage + vlds/vabs/vsts loop. | +| `sqrt` | `implemented` | `sqrt_micro` | UB stage + vlds/vsqrt/vsts loop. | +| `rsqrt` | `implemented` | `rsqrt_micro` | UB stage + vsqrt/vrec micro sequence. | +| `reciprocal` | `implemented` | `reciprocal_micro` | UB stage + vlds/vrec/vsts loop. | +| `matmul` | `blocked` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. | +| `matmul_bias` | `blocked` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. | +| `matmul_acc` | `blocked` | `-` | Cube/L0 path is not a pure vector-micro rewrite target. | +| `extract` | `blocked` | `-` | Layout/L0 extraction op, not a vector-micro compute rewrite. | +| `row_sum` | `implemented` | `row_sum_micro` | Static-shape row reduction via vcadd + point-store. | +| `row_min` | `implemented` | `row_min_micro` | Static-shape row reduction via vcmin + point-store. | +| `row_max` | `implemented` | `row_max_micro` | Static-shape row reduction via vcmax + point-store. | +| `row_expand` | `implemented` | `row_expand_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vsts. | +| `row_expand_sub` | `implemented` | `row_expand_sub_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vsub/vsts. 
| +| `row_expand_div` | `implemented` | `row_expand_div_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vdiv/vsts. | +| `row_expand_mul` | `implemented` | `row_expand_mul_micro` | Static-shape canonical broadcast via vldas/vldus/vdup/vmul/vsts. | +| `col_sum` | `implemented` | `col_sum_micro` | Static-shape TColReduceOps-style column reduction via vadd. | +| `col_min` | `implemented` | `col_min_micro` | Static-shape TColReduceOps-style column reduction via vmin. | +| `col_max` | `implemented` | `col_max_micro` | Static-shape TColReduceOps-style column reduction via vmax. | +| `col_expand` | `implemented` | `col_expand_micro` | Static-shape canonical broadcast via vlds/vsts replication. | +| `mrgsort` | `implemented` | `mrgsort_micro` | Single-list row-major merge sort via vmrgsort4. | +| `sort32` | `implemented` | `sort32_micro` | Static-shape block sort via vbitsort. | +| `subset` | `not_applicable` | `-` | View helper only, not a tile compute op. | diff --git a/ptodsl/lib/a5/__init__.py b/ptodsl/lib/a5/__init__.py new file mode 100644 index 00000000..61670f55 --- /dev/null +++ b/ptodsl/lib/a5/__init__.py @@ -0,0 +1,27 @@ +from . 
import ops +from .kernels import ( + KERNEL_BUILDERS, + build_cube_matmul, + build_elementwise_add, + build_micro_vector_copy, + build_mxfp8_matmul, + build_templated_elementwise_add, +) +from .ops import * +from .tile_micro_coverage import ( + TILE_MICRO_COVERAGE, + coverage_markdown, + coverage_summary, +) + +__all__ = list(ops.__all__) + [ + "KERNEL_BUILDERS", + "TILE_MICRO_COVERAGE", + "build_cube_matmul", + "build_elementwise_add", + "build_micro_vector_copy", + "build_mxfp8_matmul", + "build_templated_elementwise_add", + "coverage_markdown", + "coverage_summary", +] diff --git a/ptodsl/lib/a5/generated/a5_cube_matmul.pto b/ptodsl/lib/a5/generated/a5_cube_matmul.pto new file mode 100644 index 00000000..7f52f654 --- /dev/null +++ b/ptodsl/lib/a5/generated/a5_cube_matmul.pto @@ -0,0 +1,46 @@ +module { + func.func @a5_cube_matmul(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr) { + %c0 = arith.constant 0 : index + %c16 = arith.constant 16 : index + %c32 = arith.constant 32 : index + %c1 = arith.constant 1 : index + %0 = pto.make_tensor_view %arg0, shape = [%c16, %c32], strides = [%c32, %c1] : !pto.tensor_view + %c32_0 = arith.constant 32 : index + %c16_1 = arith.constant 16 : index + %c1_2 = arith.constant 1 : index + %1 = pto.make_tensor_view %arg1, shape = [%c32_0, %c16_1], strides = [%c16_1, %c1_2] : !pto.tensor_view + %c16_3 = arith.constant 16 : index + %c16_4 = arith.constant 16 : index + %c1_5 = arith.constant 1 : index + %2 = pto.make_tensor_view %arg2, shape = [%c16_3, %c16_4], strides = [%c16_4, %c1_5] : !pto.tensor_view + pto.section.cube { + %c0_6 = arith.constant 0 : index + %c0_7 = arith.constant 0 : index + %c16_8 = arith.constant 16 : index + %c32_9 = arith.constant 32 : index + %3 = pto.partition_view %0, offsets = [%c0_6, %c0_7], sizes = [%c16_8, %c32_9] : !pto.tensor_view -> !pto.partition_tensor_view<16x32xf16> + %4 = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%3 : !pto.partition_tensor_view<16x32xf16>) outs(%4 : !pto.tile_buf) + %c0_10 
= arith.constant 0 : index + %c0_11 = arith.constant 0 : index + %c32_12 = arith.constant 32 : index + %c16_13 = arith.constant 16 : index + %5 = pto.partition_view %1, offsets = [%c0_10, %c0_11], sizes = [%c32_12, %c16_13] : !pto.tensor_view -> !pto.partition_tensor_view<32x16xf16> + %6 = pto.alloc_tile : !pto.tile_buf + pto.tload ins(%5 : !pto.partition_tensor_view<32x16xf16>) outs(%6 : !pto.tile_buf) + %7 = pto.alloc_tile : !pto.tile_buf + %8 = pto.alloc_tile : !pto.tile_buf + %9 = pto.alloc_tile : !pto.tile_buf + pto.textract ins(%4, %c0, %c0 : !pto.tile_buf, index, index) outs(%7 : !pto.tile_buf) + pto.tmov ins(%6 : !pto.tile_buf) outs(%8 : !pto.tile_buf) + pto.tmatmul ins(%7, %8 : !pto.tile_buf, !pto.tile_buf) outs(%9 : !pto.tile_buf) + %c0_14 = arith.constant 0 : index + %c0_15 = arith.constant 0 : index + %c16_16 = arith.constant 16 : index + %c16_17 = arith.constant 16 : index + %10 = pto.partition_view %2, offsets = [%c0_14, %c0_15], sizes = [%c16_16, %c16_17] : !pto.tensor_view -> !pto.partition_tensor_view<16x16xf32> + pto.tstore ins(%9 : !pto.tile_buf) outs(%10 : !pto.partition_tensor_view<16x16xf32>) + } + return + } +} diff --git a/ptodsl/lib/a5/generated/a5_elementwise_add.pto b/ptodsl/lib/a5/generated/a5_elementwise_add.pto new file mode 100644 index 00000000..598b5bfb --- /dev/null +++ b/ptodsl/lib/a5/generated/a5_elementwise_add.pto @@ -0,0 +1,50 @@ +module { + func.func @a5_elementwise_add(%arg0: !pto.ptr, %arg1: !pto.ptr, %arg2: !pto.ptr, %arg3: index, %arg4: index) { + %c1 = arith.constant 1 : index + %0 = pto.make_tensor_view %arg0, shape = [%arg3, %arg4], strides = [%arg4, %c1] : !pto.tensor_view + %c1_0 = arith.constant 1 : index + %1 = pto.make_tensor_view %arg1, shape = [%arg3, %arg4], strides = [%arg4, %c1_0] : !pto.tensor_view + %c1_1 = arith.constant 1 : index + %2 = pto.make_tensor_view %arg2, shape = [%arg3, %arg4], strides = [%arg4, %c1_1] : !pto.tensor_view + %c0 = arith.constant 0 : index + %c0_2 = arith.constant 0 : index + %c32 
= arith.constant 32 : index + %c32_3 = arith.constant 32 : index + %3 = pto.partition_view %0, offsets = [%c0, %c0_2], sizes = [%c32, %c32_3] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %c0_4 = arith.constant 0 : index + %c0_5 = arith.constant 0 : index + %c32_6 = arith.constant 32 : index + %c32_7 = arith.constant 32 : index + %4 = pto.partition_view %1, offsets = [%c0_4, %c0_5], sizes = [%c32_6, %c32_7] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + %c0_8 = arith.constant 0 : index + %c0_9 = arith.constant 0 : index + %c32_10 = arith.constant 32 : index + %c32_11 = arith.constant 32 : index + %5 = pto.partition_view %2, offsets = [%c0_8, %c0_9], sizes = [%c32_10, %c32_11] : !pto.tensor_view -> !pto.partition_tensor_view<32x32xf32> + pto.section.vector { + %c0_i64 = arith.constant 0 : i64 + %c4096_i64 = arith.constant 4096 : i64 + %c8192_i64 = arith.constant 8192 : i64 + %6 = pto.alloc_tile addr = %c0_i64 : !pto.tile_buf + %7 = pto.alloc_tile addr = %c4096_i64 : !pto.tile_buf + %8 = pto.alloc_tile addr = %c8192_i64 : !pto.tile_buf + pto.tload ins(%3 : !pto.partition_tensor_view<32x32xf32>) outs(%6 : !pto.tile_buf) + pto.tload ins(%4 : !pto.partition_tensor_view<32x32xf32>) outs(%7 : !pto.tile_buf) + %9 = pto.castptr %c0_i64 : i64 -> !pto.ptr + %10 = pto.castptr %c4096_i64 : i64 -> !pto.ptr + %11 = pto.castptr %c8192_i64 : i64 -> !pto.ptr + %12 = pto.pset_b32 "PAT_ALL" : !pto.mask + %c0_12 = arith.constant 0 : index + %c1024 = arith.constant 1024 : index + %c64 = arith.constant 64 : index + scf.for %arg5 = %c0_12 to %c1024 step %c64 { + %13 = pto.vlds %9[%arg5] : !pto.ptr -> !pto.vreg<64xf32> + %14 = pto.vlds %10[%arg5] : !pto.ptr -> !pto.vreg<64xf32> + %15 = pto.vadd %13, %14, %12 : !pto.vreg<64xf32>, !pto.vreg<64xf32>, !pto.mask -> !pto.vreg<64xf32> + pto.vsts %15, %11[%arg5], %12 : !pto.vreg<64xf32>, !pto.ptr, !pto.mask + } + pto.tstore ins(%8 : !pto.tile_buf) outs(%5 : !pto.partition_tensor_view<32x32xf32>) + } + return + } +} 
def build_elementwise_add(*, rows=32, cols=32, tile_rows=32, tile_cols=32, dtype=None):
    """Build the `a5_elementwise_add` kernel via `to_ir_module`.

    The kernel takes three pointers (`src0`, `src1`, `dst`) plus two runtime
    index arguments (`n_rows`, `n_cols`) and adds one tile through
    `ops.add_micro`.

    Args:
        rows: Accepted for signature compatibility; not read by this builder.
        cols: Accepted for signature compatibility; not read by this builder.
        tile_rows: Static row count of the tile that is sliced and added.
        tile_cols: Static column count of the tile that is sliced and added.
        dtype: Element type; defaults to `pto.float32` when `None`.

    Returns:
        Whatever `to_ir_module` produces for the decorated kernel function.
    """
    dtype = pto.float32 if dtype is None else dtype

    # Names used as string annotations on the kernel signature below.
    def meta_data():
        return {
            "ptr_t": pto.ptr(dtype),
            "index_t": IndexType.get(),
        }

    @to_ir_module(meta_data=meta_data)
    def a5_elementwise_add(
        src0: "ptr_t",
        src1: "ptr_t",
        dst: "ptr_t",
        n_rows: "index_t",
        n_cols: "index_t",
    ) -> None:
        # Tensor shapes come from the runtime index arguments.
        lhs = pto.make_tensor(src0, shape=[n_rows, n_cols], dtype=dtype)
        rhs = pto.make_tensor(src1, shape=[n_rows, n_cols], dtype=dtype)
        out = pto.make_tensor(dst, shape=[n_rows, n_cols], dtype=dtype)

        # Only the tile at origin [0, 0] is processed; there is no loop over
        # further tiles of the tensor.
        lhs_tile = lhs.slice([0, 0], [tile_rows, tile_cols])
        rhs_tile = rhs.slice([0, 0], [tile_rows, tile_cols])
        out_tile = out.slice([0, 0], [tile_rows, tile_cols])

        with pto.vector_section():
            ops.add_micro(
                lhs_tile,
                rhs_tile,
                out_tile,
                dtype=dtype,
                shape=[tile_rows, tile_cols],
            )

    return a5_elementwise_add
def build_micro_vector_copy(*, lanes=64, dtype=None):
    """Build the `a5_micro_vector_copy` kernel via `to_ir_module`.

    The kernel takes a source pointer, a destination pointer and an index
    offset, and forwards them to `ops.vector_copy` inside a vector section.

    Args:
        lanes: Lane count forwarded to `ops.vector_copy`.
        dtype: Element type; defaults to `pto.float32` when `None`.

    Returns:
        Whatever `to_ir_module` produces for the decorated kernel function.
    """
    dtype = pto.float32 if dtype is None else dtype

    # Both pointers are declared in the "VEC" space, unlike the tile-based
    # builders in this module, which use `pto.ptr(dtype)` without a space.
    def meta_data():
        return {
            "ptr_t": pto.ptr(dtype, space="VEC"),
            "index_t": IndexType.get(),
        }

    @to_ir_module(meta_data=meta_data)
    def a5_micro_vector_copy(src: "ptr_t", dst: "ptr_t", offset: "index_t") -> None:
        with pto.vector_section():
            ops.vector_copy(src, dst, offset, lanes=lanes, dtype=dtype)

    return a5_micro_vector_copy
def build_cube_matmul(
    *, m=16, k=32, n=16, lhs_dtype=None, rhs_dtype=None, acc_dtype=None
):
    """Build the `a5_cube_matmul` kernel via `to_ir_module`.

    The kernel multiplies an `[m, k]` operand by a `[k, n]` operand into an
    `[m, n]` output inside a cube section.

    Args:
        m: Static row count of the left operand and of the output.
        k: Static shared dimension.
        n: Static column count of the right operand and of the output.
        lhs_dtype: Left operand element type; defaults to `pto.float16`.
        rhs_dtype: Right operand element type; defaults to `pto.float16`.
        acc_dtype: Accumulator/output element type; defaults to `pto.float32`.

    Returns:
        Whatever `to_ir_module` produces for the decorated kernel function.
    """
    lhs_dtype = pto.float16 if lhs_dtype is None else lhs_dtype
    rhs_dtype = pto.float16 if rhs_dtype is None else rhs_dtype
    acc_dtype = pto.float32 if acc_dtype is None else acc_dtype

    # Names used as string annotations on the kernel signature below.
    def meta_data():
        return {
            "ptr_lhs": pto.ptr(lhs_dtype),
            "ptr_rhs": pto.ptr(rhs_dtype),
            "ptr_out": pto.ptr(acc_dtype),
        }

    @to_ir_module(meta_data=meta_data)
    def a5_cube_matmul(
        lhs_ptr: "ptr_lhs", rhs_ptr: "ptr_rhs", out_ptr: "ptr_out"
    ) -> None:
        c0 = s.const(0)
        lhs = pto.make_tensor(lhs_ptr, shape=[m, k], dtype=lhs_dtype)
        rhs = pto.make_tensor(rhs_ptr, shape=[k, n], dtype=rhs_dtype)
        out = pto.make_tensor(out_ptr, shape=[m, n], dtype=acc_dtype)

        with pto.cube_section():
            # Stage both operands into "MAT" tiles first.
            lhs_mat = ops.load_tile(
                lhs.slice([0, 0], [m, k]), dtype=lhs_dtype, shape=[m, k], space="MAT"
            )
            rhs_mat = ops.load_tile(
                rhs.slice([0, 0], [k, n]), dtype=rhs_dtype, shape=[k, n], space="MAT"
            )
            lhs_tile = pto.make_tile_buffer(lhs_dtype, [m, k], space="LEFT").alloc()
            rhs_tile = pto.make_tile_buffer(rhs_dtype, [k, n], space="RIGHT").alloc()
            acc_tile = pto.make_tile_buffer(acc_dtype, [m, n], space="ACC").alloc()
            # The left operand goes MAT -> LEFT through `extract` at offset
            # (0, 0); the right operand goes MAT -> RIGHT through a plain move.
            ops.extract(lhs_mat, c0, c0, lhs_tile)
            ops.move_tile(rhs_mat, rhs_tile)
            ops.matmul(lhs_tile, rhs_tile, acc_tile)
            ops.store_tile(acc_tile, out.slice([0, 0], [m, n]))

    return a5_cube_matmul
def _dtype_token(dtype):
    """Return the canonical dtype token (a key of `_DTYPE_ALIAS_GROUPS`).

    The dtype is resolved with `_scalar.resolve_type`, rendered as lowercase
    text, and matched against every alias in `_DTYPE_ALIAS_GROUPS`.

    Raises:
        ValueError: if no alias occurs in the dtype text.
    """
    text = str(_scalar.resolve_type(dtype)).lower()
    hits = [
        (alias, canonical)
        for canonical, aliases in _DTYPE_ALIAS_GROUPS.items()
        for alias in aliases
        if alias in text
    ]
    if not hits:
        raise ValueError(f"Unsupported dtype token for '{dtype}'.")
    # Several aliases are substrings of others ("f16" in "bf16", "float16" in
    # "bfloat16", "int32" in "uint32"). Taking the first group that matches
    # would report bf16 as f16, so prefer the longest alias. `max` keeps the
    # earliest hit on ties, which preserves the table's declaration order.
    return max(hits, key=lambda hit: len(hit[0]))[1]
def _extract_static_tensor_shape(value):
    """Return the static shape parsed from a value's PTO type text, if any.

    The type's string form is searched for a `!pto.tensor_view<...>`,
    `!pto.partition_tensor_view<...>` or `!pto.tile_buf<..., ...>` payload,
    and every `<dim>x` prefix in that payload is read as one dimension.

    Returns:
        A list of ints, or `None` when the value has no `type` attribute, the
        type text does not match, no dimensions are found, or any dimension is
        dynamic (`?`).
    """
    raw = _unwrap(value)
    type_obj = getattr(raw, "type", None)
    if type_obj is None:
        return None
    text = str(type_obj)
    # The groups must be named: the lookups below use "payload" and
    # "tile_payload". A bare `(?P[...])` is not valid regex syntax and makes
    # `re.search` raise `re.error`.
    match = re.search(
        r"!pto\.(?:partition_)?tensor_view<(?P<payload>[^>]+)>"
        r"|!pto\.tile_buf<[^,]+,\s*(?P<tile_payload>[^>]+)>",
        text,
    )
    if not match:
        return None
    payload = match.group("payload") or match.group("tile_payload")
    dims = re.findall(r"(\?|\d+)x", payload)
    if not dims:
        return None
    shape = []
    for dim in dims:
        if dim == "?":
            # One dynamic dimension makes the whole shape non-static.
            return None
        shape.append(int(dim))
    return shape
def _onept_dist(dtype):
    """Return the `ONEPT_B*` name matching the byte width of `dtype`.

    Raises:
        ValueError: if the byte width is not 4, 2 or 1.
    """
    dist_by_width = {4: "ONEPT_B32", 2: "ONEPT_B16", 1: "ONEPT_B8"}
    dist = dist_by_width.get(_dtype_byte_width(dtype))
    if dist is None:
        raise ValueError(f"Unsupported dtype point-store width for '{dtype}'.")
    return dist
def load_tile(
    view,
    tile_buffer=None,
    *,
    dtype=None,
    shape=None,
    space="VEC",
    valid_shape=None,
    config=None,
):
    """Load `view` into a tile buffer and return that buffer.

    When `tile_buffer` is omitted, a buffer is allocated from `dtype`,
    `shape`, `space`, `valid_shape` and `config` before loading.

    Raises:
        ValueError: if `tile_buffer` is omitted and `dtype` or `shape` is
            missing.
    """
    must_allocate = tile_buffer is None
    if must_allocate and (dtype is None or shape is None):
        raise ValueError(
            "`load_tile(...)` requires either `tile_buffer=` or both `dtype=` and `shape=`."
        )
    target = (
        _alloc_like_view(
            view,
            dtype=dtype,
            shape=shape,
            space=space,
            valid_shape=valid_shape,
            config=config,
        )
        if must_allocate
        else tile_buffer
    )
    _dsl_pto.load(view, target)
    return target
def mov_micro(src_view, out_view, *, dtype, shape, lanes=None, base_addr=0):
    """Run `_unary_micro` from `src_view` to `out_view` with no op (`op_name=None`)."""
    forwarded = {
        "dtype": dtype,
        "shape": shape,
        "lanes": lanes,
        "base_addr": base_addr,
        "op_name": None,
    }
    return _unary_micro(src_view, out_view, **forwarded)
def col_expand_micro(src_view, out_view, *, dtype, shape, base_addr=0):
    """Replicate the first row of a source tile into every row of the output.

    Args:
        src_view: Source view; loaded into a tile whose valid shape is
            `[1, cols]`.
        out_view: Destination view that receives the `[rows, cols]` result.
        dtype: Element type of both views.
        shape: Static `[rows, cols]` shape, validated by
            `_check_col_expand_operands`.
        base_addr: Integer address of the source tile; the output tile is
            placed directly after it.

    Returns:
        `out_view`.
    """
    rows, cols = _check_col_expand_operands(
        src_view, out_view, dtype=dtype, shape=shape, context="TCOLEXPAND"
    )
    lanes = _micro_lane_count(dtype)
    vreg_type = _dsl_pto.VRegType(lanes, dtype)
    buf_bytes = rows * cols * _dtype_byte_width(dtype)

    # The two tiles sit back to back: source at `base_addr`, output one full
    # tile (`buf_bytes`) later.
    src_addr = _const_i64(base_addr)
    out_addr = _const_i64(base_addr + buf_bytes)

    src_tile = _dsl_pto.make_tile_buffer(
        dtype, shape, space="VEC", valid_shape=[1, cols]
    ).alloc(addr=src_addr)
    out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=out_addr)

    _dsl_pto.load(src_view, src_tile)

    # Typed pointers onto the same addresses, used by the vector load/store
    # calls below.
    src_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), src_addr)
    out_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), out_addr)

    # These are plain Python loops over static extents, so one call is issued
    # per column chunk (load) and per chunk/row pair (store).
    for col in range(0, cols, lanes):
        # The last chunk may be narrower than a full vector.
        active = builtins.min(lanes, cols - col)
        mask = _mask_for_chunk(dtype, active)
        col_offset = _scalar.const(col)
        # Load the chunk once from row 0, then store it into every row.
        vec = _dsl_pto.vlds(vreg_type, src_ptr, col_offset)
        for row in range(rows):
            dst_offset = _scalar.const(row * cols + col)
            _dsl_pto.vsts(vec, out_ptr, dst_offset, mask)

    _dsl_pto.store(out_tile, out_view)
    return out_view
def row_expand_sub_micro(
    base_view, expand_view, out_view, *, dtype, shape, base_addr=0
):
    """Run `_row_expand_binary_micro` with the `vsub` op."""
    forwarded = {
        "dtype": dtype,
        "shape": shape,
        "base_addr": base_addr,
        "op_name": "vsub",
    }
    return _row_expand_binary_micro(base_view, expand_view, out_view, **forwarded)
def _require_static_matrix_shape(shape, *, context):
    """Validate that ``shape`` is a static, strictly positive 2-D shape.

    Args:
        shape: Sized iterable expected to hold exactly two ``int`` extents.
        context: Operation name used as the prefix of the error message.

    Returns:
        The ``(rows, cols)`` pair taken from ``shape``.

    Raises:
        ValueError: If ``shape`` is not rank 2, contains a non-``int`` extent,
            or contains an extent that is zero or negative.
    """
    # The rank test comes first so the element scan is skipped on a rank mismatch.
    well_formed = len(shape) == 2 and all(isinstance(extent, int) for extent in shape)
    if not well_formed:
        raise ValueError(f"{context} currently requires a static rank-2 integer shape.")

    n_rows, n_cols = shape
    if not (n_rows > 0 and n_cols > 0):
        raise ValueError(f"{context} requires positive row/column sizes.")
    return n_rows, n_cols
_require_supported_dtype( + dtype, + allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, + message=f"Fix: {context} data type must be b8/b16/b32", + ) + _require_view_shape( + src_view, + [1, cols], + context=context, + message=f"Fix: {context} input valid col must be consistent with output valid col.", + ) + _require_view_shape( + out_view, + [rows, cols], + context=context, + message=f"Fix: {context} output valid shape mismatch.", + ) + _require_view_dtype( + src_view, + dtype, + message=f"Fix: {context} input data type must be consistent with the output data type.", + ) + _require_view_dtype( + out_view, + dtype, + message=f"Fix: {context} input data type must be consistent with the output data type.", + ) + return rows, cols + + +def _check_row_reduce_operands(src_view, out_view, *, dtype, shape, context): + rows, cols = _require_static_matrix_shape(shape, context=context) + _require_supported_dtype( + dtype, + allowed={"f32", "f16", "i32", "i16"}, + message=( + "Row reduction only supports 'half', 'float', 'int32', or 'int16' data types. " + "Fix: Define TileDataIn with DType = half, float, int32, or int16." 
+ ), + ) + _require_view_shape( + src_view, + [rows, cols], + context=context, + message="Fix: Ensure src valid shape matches [rows, cols].", + ) + _require_view_shape( + out_view, + [rows, 1], + context=context, + message="Fix: Pass dstValidRow = srcValidRows and use a single-column output tile.", + ) + _require_view_dtype( + src_view, + dtype, + message="Fix: Ensure TileDataOut uses the same DType as TileDataIn.", + ) + _require_view_dtype( + out_view, + dtype, + message="Fix: Ensure TileDataOut uses the same DType as TileDataIn.", + ) + return rows, cols + + +def _check_col_reduce_operands(src_view, out_view, *, dtype, shape, context): + rows, cols = _require_static_matrix_shape(shape, context=context) + _require_supported_dtype( + dtype, + allowed={"f32", "f16", "bf16", "i32", "u32", "i16", "u16", "i8", "u8"}, + message=f"Fix: {context} input data type is not supported by this instruction.", + ) + _require_view_shape( + src_view, + [rows, cols], + context=context, + message=f"Fix: {context} input shape mismatch.", + ) + _require_view_shape( + out_view, + [1, cols], + context=context, + message=f"Fix: {context} input valid row must be consistent with the output valid row.", + ) + _require_view_dtype( + src_view, + dtype, + message=f"Fix: {context} input data type must be consistent with the output data type.", + ) + _require_view_dtype( + out_view, + dtype, + message=f"Fix: {context} input data type must be consistent with the output data type.", + ) + return rows, cols + + +def _check_gather_operands( + src_view, indices_view, out_view, *, dtype, index_dtype, shape +): + rows, cols = _require_static_matrix_shape(shape, context="TGATHER") + dtype_token = _require_supported_dtype( + dtype, + allowed={"f32", "f16", "i32", "u32", "i16", "u16"}, + message="Fix: TGATHER Src data type must be int16_t/uint16_t/int32_t/uint32_t/half/float.", + ) + index_token = _require_supported_dtype( + index_dtype, + allowed={"i32", "u32", "i16", "u16"}, + message="Fix: TGATHER 
expect b16/b32", + ) + if _dtype_byte_width(dtype) != _dtype_byte_width(index_dtype): + raise ValueError( + "Fix: TGATHER micro lowering currently supports same-width source/index pairs only." + ) + for view, expected_shape, label in ( + (src_view, [rows, cols], "src"), + (indices_view, [rows, cols], "indices"), + (out_view, [rows, cols], "dst"), + ): + _require_view_shape( + view, + expected_shape, + context="TGATHER", + message=f"Fix: TGATHER {label} shape mismatch.", + ) + _require_view_dtype( + src_view, + dtype, + message="Fix: TGATHER expect same type size for dst and src", + ) + _require_view_dtype( + out_view, + dtype, + message="Fix: TGATHER expect same type size for dst and src", + ) + _require_view_dtype( + indices_view, + index_dtype, + message="Fix: TGATHER expect b16/b32", + ) + return rows, cols, dtype_token, index_token + + +def _check_mrgsort_operands(src_view, out_view, *, dtype, shape, block_len): + rows, cols = _require_static_matrix_shape(shape, context="TMRGSORT") + _require_supported_dtype( + dtype, + allowed={"f32", "f16"}, + message="TMrgsort: Unsupported data type! 
Supported types is half/float", + ) + if rows != 1: + raise ValueError("TMrgsort: the row of Destination and Source tile must be 1.") + if block_len <= 0 or cols % (block_len * 4) != 0: + raise ValueError("TMrgsort: src columns must be divisible by blockLen * 4.") + _require_view_shape( + src_view, + [rows, cols], + context="TMRGSORT", + message="TMrgsort: source tile shape mismatch.", + ) + _require_view_shape( + out_view, + [rows, cols], + context="TMRGSORT", + message="TMrgsort: destination tile shape mismatch.", + ) + _require_view_dtype( + src_view, + dtype, + message="TMrgsort: Destination and Source tile data types must be the same.", + ) + _require_view_dtype( + out_view, + dtype, + message="TMrgsort: Destination and Source tile data types must be the same.", + ) + return rows, cols + + +def _check_sort32_operands(src_view, idx_view, out_view, *, dtype, shape): + rows, cols = _require_static_matrix_shape(shape, context="TSORT32") + _require_supported_dtype( + dtype, + allowed={"f32", "f16"}, + message="Dst and src must be float or half.", + ) + out_cols = cols * (2 if _dtype_token(dtype) == "f32" else 4) + for view, expected_shape, label in ( + (src_view, [rows, cols], "src"), + (idx_view, [rows, cols], "idx"), + (out_view, [rows, out_cols], "dst"), + ): + _require_view_shape( + view, + expected_shape, + context="TSORT32", + message=f"TSORT32 {label} shape mismatch.", + ) + _require_view_dtype( + src_view, + dtype, + message="Dst and src mube be same.", + ) + _require_view_dtype( + out_view, + dtype, + message="Dst and src mube be same.", + ) + _require_view_dtype( + idx_view, + _dsl_pto.uint32, + message="Idx must be uint32_t.", + ) + if cols % 32 != 0: + raise ValueError( + "TSORT32 micro lowering currently requires column count divisible by 32." 
def _row_expand_binary_micro(
    base_view, expand_view, out_view, *, dtype, shape, base_addr, op_name
):
    """Lower ``out = base <op> expand`` where ``expand`` supplies one value per row.

    Args:
        base_view: Left operand view; validated to have shape ``[rows, cols]``.
        expand_view: Per-row operand view; validated to have shape ``[rows, 1]``.
        out_view: Destination view; validated to have shape ``[rows, cols]``.
        dtype: Element type shared by all three views.
        shape: Static ``[rows, cols]`` shape used for all three tile buffers.
        base_addr: Start address of the scratch region. The base, expand and
            output tiles are laid out back to back from this address.
        op_name: Name of the micro op looked up on ``_dsl_pto`` (for example
            ``"vsub"``). The text after its first character is upper-cased
            to build the ``TROWEXPAND_<OP>`` name used in diagnostics.

    Returns:
        ``out_view``, after the output tile has been stored into it.

    Raises:
        ValueError: Propagated from the operand checks.
    """
    # Build the tile-level op name once and use it for every check. Previously
    # the base-view shape check was handed the raw micro-op name as its
    # context, unlike the other checks in this function.
    context = f"TROWEXPAND_{op_name[1:].upper()}"
    rows, cols = _check_row_expand_operands(
        expand_view,
        out_view,
        dtype=dtype,
        shape=shape,
        context=context,
    )
    _require_view_shape(
        base_view,
        [rows, cols],
        context=context,
        message=f"Fix: {context} base input valid shape mismatch with output tile dst shape.",
    )
    _require_view_dtype(
        base_view,
        dtype,
        message=f"Fix: {context} input data type must be consistent with the output data type.",
    )
    lanes = _micro_lane_count(dtype)
    vreg_type = _dsl_pto.VRegType(lanes, dtype)
    # Every tile is sized for the full [rows, cols] shape.
    buf_bytes = rows * cols * _dtype_byte_width(dtype)

    base_addr_value = _const_i64(base_addr)
    expand_addr_value = _const_i64(base_addr + buf_bytes)
    out_addr_value = _const_i64(base_addr + buf_bytes * 2)

    base_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(
        addr=base_addr_value
    )
    # Only the first column of the expand tile is marked valid.
    expand_tile = _dsl_pto.make_tile_buffer(
        dtype, shape, space="VEC", valid_shape=[rows, 1]
    ).alloc(addr=expand_addr_value)
    out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(
        addr=out_addr_value
    )

    _dsl_pto.load(base_view, base_tile)
    _dsl_pto.load(expand_view, expand_tile)

    base_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), base_addr_value)
    expand_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), expand_addr_value)
    out_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), out_addr_value)
    micro_op = getattr(_dsl_pto, op_name)

    for row in range(rows):
        # Read from the start of this row in the expand buffer and turn it into
        # a full vector via vdup (presumably a broadcast of the lowest lane,
        # given position="POS_LOWEST" - confirm against the vdup definition).
        scalar_offset = _scalar.const(row * cols)
        align = _dsl_pto.vldas(_dsl_pto.AlignType(), expand_ptr, scalar_offset)
        scalar_vec, _, _ = _dsl_pto.vldus(
            vreg_type,
            _dsl_pto.AlignType(),
            _dsl_pto.ptr(dtype, space="VEC"),
            expand_ptr,
            scalar_offset,
            align,
        )
        broadcast = _dsl_pto.vdup(vreg_type, scalar_vec, position="POS_LOWEST")
        # Apply the op chunk by chunk; the last chunk of a row may be partial.
        for col in range(0, cols, lanes):
            active = builtins.min(lanes, cols - col)
            mask = _mask_for_chunk(dtype, active)
            row_offset = _scalar.const(row * cols + col)
            base_vec = _dsl_pto.vlds(vreg_type, base_ptr, row_offset)
            out_vec = micro_op(vreg_type, base_vec, broadcast, mask)
            _dsl_pto.vsts(out_vec, out_ptr, row_offset, mask)

    _dsl_pto.store(out_tile, out_view)
    return out_view
def _col_reduce_micro_no_post_update(
    src_ptr, out_ptr, *, dtype, rows, cols, lanes, vreg_type, reduce_op
):
    """Emit a column reduction that addresses every load with an explicit offset.

    For each chunk of ``lanes`` columns, row 0 seeds the accumulator and the
    remaining rows are folded in with ``reduce_op``. The result chunk is
    written to the same column range of ``out_ptr``.

    Args:
        src_ptr: Pointer to the ``rows * cols`` source elements, indexed as
            ``row * cols + col``.
        out_ptr: Pointer to the ``cols`` output elements.
        dtype: Element type, used to build the per-chunk mask.
        rows: Number of source rows.
        cols: Number of columns.
        lanes: Number of elements processed per vector.
        vreg_type: Vector register type passed to every load and reduce.
        reduce_op: Callable invoked as ``reduce_op(vreg_type, a, b, mask)``.
    """
    # Rows after row 0 are consumed two at a time; `remain` is 1 when one row
    # is left over after the pairs.
    loop_pairs = (rows - 1) // 2
    remain = (rows - 1) % 2
    for col in range_constexpr(0, cols, lanes):
        # The last chunk may hold fewer than `lanes` columns.
        active = builtins.min(lanes, cols - col)
        mask = _mask_for_chunk(dtype, active)
        # Row 0 is the initial accumulator value.
        accum = _dsl_pto.vlds(vreg_type, src_ptr, _scalar.const(col))
        for pair in range_constexpr(loop_pairs):
            row0 = 2 * pair + 1
            row1 = 2 * pair + 2
            src0 = _dsl_pto.vlds(vreg_type, src_ptr, _scalar.const(col + row0 * cols))
            src1 = _dsl_pto.vlds(vreg_type, src_ptr, _scalar.const(col + row1 * cols))
            # Combine the pair first, then fold the pair result into the accumulator.
            tmp = reduce_op(vreg_type, src0, src1, mask)
            accum = reduce_op(vreg_type, accum, tmp, mask)
        if const_expr(remain):
            # Unpaired last row: its index is the first row the pair loop did not reach.
            tail_row = 2 * loop_pairs + 1
            src_tail = _dsl_pto.vlds(
                vreg_type, src_ptr, _scalar.const(col + tail_row * cols)
            )
            accum = reduce_op(vreg_type, accum, src_tail, mask)
        _dsl_pto.vsts(accum, out_ptr, _scalar.const(col), mask)
def _mrgsort_micro(src_view, out_view, *, dtype, shape, block_len, base_addr):
    """Lower TMRGSORT to a single ``vmrgsort4`` over four consecutive input blocks.

    Args:
        src_view: Source view; validated to have shape ``[1, cols]``.
        out_view: Destination view with the same shape and dtype.
        dtype: Element type shared by both views.
        shape: Static ``[1, cols]`` shape used for both tile buffers.
        block_len: Element distance between the four input list pointers.
            ``cols`` must be a multiple of ``block_len * 4``.
        base_addr: Start address of the scratch region. The source tile is
            placed at ``base_addr`` and the output tile directly after it.

    Returns:
        ``out_view``, after the output tile has been stored into it.

    Raises:
        ValueError: Propagated from ``_check_mrgsort_operands``, or raised
            here when a derived value does not fit the bit field it is packed
            into.
    """
    _, cols = _check_mrgsort_operands(
        src_view, out_view, dtype=dtype, shape=shape, block_len=block_len
    )
    elem_bytes = _dtype_byte_width(dtype)

    # Derive and validate the packed operands before emitting any IR, so an
    # invalid configuration fails without leaving a half-built kernel behind.
    # The per-list length is expressed in 8-byte units and replicated into four
    # 16-bit slots; a value of 0 or above 0xFFFF cannot be represented there.
    num_structures = (block_len * elem_bytes) >> 3
    if not 0 < num_structures <= 0xFFFF:
        raise ValueError(
            "TMrgsort: blockLen * sizeof(dtype) must cover between 1 and 65535 "
            f"8-byte structures, got {num_structures}."
        )
    # The repeat count shares a word with the 0b1111 flags placed at bit 8, so
    # anything above 0xFF would spill into those flag bits.
    repeat_times = cols // (block_len * 4)
    if repeat_times > 0xFF:
        raise ValueError(
            f"TMrgsort: repeat count {repeat_times} (cols / (blockLen * 4)) "
            "does not fit below the flag bits of the config word (max 255)."
        )
    count_value = (
        num_structures
        | (num_structures << 16)
        | (num_structures << 32)
        | (num_structures << 48)
    )
    config_value = repeat_times | (0b1111 << 8)

    src_addr = _const_i64(base_addr)
    # rows == 1 is enforced by the operand check, so one row is the whole tile.
    out_addr = _const_i64(base_addr + cols * elem_bytes)

    src_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=src_addr)
    out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=out_addr)
    _dsl_pto.load(src_view, src_tile)

    ptr_type = _dsl_pto.ptr(dtype, space="VEC")
    src_ptr = _dsl_pto.castptr(ptr_type, src_addr)
    out_ptr = _dsl_pto.castptr(ptr_type, out_addr)

    # The four input lists start block_len elements apart.
    src1_ptr = _dsl_pto.addptr(src_ptr, _scalar.const(block_len))
    src2_ptr = _dsl_pto.addptr(src_ptr, _scalar.const(block_len * 2))
    src3_ptr = _dsl_pto.addptr(src_ptr, _scalar.const(block_len * 3))

    _dsl_pto.vmrgsort4(
        out_ptr,
        src_ptr,
        src1_ptr,
        src2_ptr,
        src3_ptr,
        _const_i64(count_value),
        _const_i64(config_value),
    )
    _dsl_pto.store(out_tile, out_view)
    return out_view
def _binary_micro(
    lhs_view, rhs_view, out_view, *, dtype, shape, lanes, base_addr, op_name, impl
):
    """Lower an element-wise binary tile op to micro instructions.

    Validates the operands, stages both inputs into tile buffers, dispatches
    to one of four loop emitters according to ``impl``, and stores the result.

    Args:
        lhs_view: Left operand view; validated to have shape ``[rows, cols]``.
        rhs_view: Right operand view with the same shape and dtype.
        out_view: Destination view with the same shape and dtype.
        dtype: Element type shared by all three views.
        shape: Static ``[rows, cols]`` shape used for all three tile buffers.
        lanes: Requested lane count; resolved through ``_resolve_lanes``.
        base_addr: Start address of the scratch region. The lhs, rhs and
            output tiles are laid out back to back from this address.
        op_name: Name of the micro op looked up on ``_dsl_pto``. Its
            upper-cased form with the first ``V`` replaced by ``T`` is used
            as the diagnostic context.
        impl: Implementation selector, normalised by ``_normalize_vf_impl_kind``.

    Returns:
        ``out_view``, after the output tile has been stored into it.

    Raises:
        ValueError: Propagated from the operand checks, or raised here when
            the normalised ``impl`` matches none of the known kinds.
    """
    rows, cols = _check_tbinop_operands(
        lhs_view,
        rhs_view,
        out_view,
        dtype=dtype,
        shape=shape,
        context=op_name.upper().replace("V", "T", 1),
    )
    lanes = _resolve_lanes(dtype, lanes)
    element_count = rows * cols
    buf_bytes = element_count * _dtype_byte_width(dtype)
    lhs_addr = _const_i64(base_addr)
    rhs_addr = _const_i64(base_addr + buf_bytes)
    out_addr = _const_i64(base_addr + buf_bytes * 2)

    lhs_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=lhs_addr)
    rhs_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=rhs_addr)
    out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=out_addr)

    _dsl_pto.load(lhs_view, lhs_tile)
    _dsl_pto.load(rhs_view, rhs_tile)

    ptr_type = _dsl_pto.ptr(dtype, space="VEC")
    vreg_type = _dsl_pto.VRegType(lanes, dtype)
    lhs_ptr = _dsl_pto.castptr(ptr_type, lhs_addr)
    rhs_ptr = _dsl_pto.castptr(ptr_type, rhs_addr)
    out_ptr = _dsl_pto.castptr(ptr_type, out_addr)
    micro_op = getattr(_dsl_pto, op_name)
    impl_kind = _normalize_vf_impl_kind(impl)
    # NOTE(review): element_count is rows * cols with both positive, so
    # `cols == element_count` holds exactly when rows == 1; the second operand
    # of this `or` never changes the result.
    is_contiguous = rows == 1 or cols == element_count
    # The default picks the flat post-update loop for single-row tiles and the
    # row/column loop otherwise.
    if const_expr(impl_kind == VF_IMPL_DEFAULT):
        impl_kind = (
            VF_IMPL_1D_POST_UPDATE if is_contiguous else VF_IMPL_2D_NO_POST_UPDATE
        )

    if const_expr(impl_kind == VF_IMPL_1D_NO_POST_UPDATE):
        _binary_micro_1d_no_post_update(
            lhs_ptr,
            rhs_ptr,
            out_ptr,
            dtype=dtype,
            lanes=lanes,
            element_count=element_count,
            vreg_type=vreg_type,
            micro_op=micro_op,
        )
    elif const_expr(impl_kind == VF_IMPL_1D_POST_UPDATE):
        _binary_micro_1d_post_update(
            lhs_ptr,
            rhs_ptr,
            out_ptr,
            ptr_type=ptr_type,
            dtype=dtype,
            lanes=lanes,
            element_count=element_count,
            vreg_type=vreg_type,
            micro_op=micro_op,
        )
    elif const_expr(impl_kind == VF_IMPL_2D_NO_POST_UPDATE):
        _binary_micro_2d_no_post_update(
            lhs_ptr,
            rhs_ptr,
            out_ptr,
            dtype=dtype,
            rows=rows,
            cols=cols,
            lanes=lanes,
            vreg_type=vreg_type,
            micro_op=micro_op,
        )
    elif const_expr(impl_kind == VF_IMPL_2D_POST_UPDATE):
        _binary_micro_2d_post_update(
            lhs_ptr,
            rhs_ptr,
            out_ptr,
            dtype=dtype,
            rows=rows,
            cols=cols,
            lanes=lanes,
            vreg_type=vreg_type,
            micro_op=micro_op,
        )
    else:
        raise ValueError(f"Unexpected normalized VF impl kind '{impl_kind}'.")

    _dsl_pto.store(out_tile, out_view)
    return out_view
active) + lhs_vec, lhs_cursor = _dsl_pto.vlds_post( + vreg_type, ptr_type, lhs_cursor, lane_step + ) + rhs_vec, rhs_cursor = _dsl_pto.vlds_post( + vreg_type, ptr_type, rhs_cursor, lane_step + ) + out_vec = micro_op(vreg_type, lhs_vec, rhs_vec, mask) + out_cursor = _dsl_pto.vsts_post(ptr_type, out_vec, out_cursor, lane_step, mask) + + +def _binary_micro_2d_no_post_update( + lhs_ptr, rhs_ptr, out_ptr, *, dtype, rows, cols, lanes, vreg_type, micro_op +): + for row in range_constexpr(rows): + row_base = row * cols + for col in range_constexpr(0, cols, lanes): + active = builtins.min(lanes, cols - col) + mask = _mask_for_chunk(dtype, active) + index = _scalar.const(row_base + col) + lhs_vec = _dsl_pto.vlds(vreg_type, lhs_ptr, index) + rhs_vec = _dsl_pto.vlds(vreg_type, rhs_ptr, index) + out_vec = micro_op(vreg_type, lhs_vec, rhs_vec, mask) + _dsl_pto.vsts(out_vec, out_ptr, index, mask) + + +def _binary_micro_2d_post_update( + lhs_ptr, rhs_ptr, out_ptr, *, dtype, rows, cols, lanes, vreg_type, micro_op +): + _binary_micro_2d_no_post_update( + lhs_ptr, + rhs_ptr, + out_ptr, + dtype=dtype, + rows=rows, + cols=cols, + lanes=lanes, + vreg_type=vreg_type, + micro_op=micro_op, + ) + + +def _rsqrt_micro(src_view, out_view, *, dtype, shape, lanes, base_addr): + if any(not isinstance(dim, int) for dim in shape): + raise ValueError( + "micro tile lowering currently requires a static integer shape." 
def _unary_micro(src_view, out_view, *, dtype, shape, lanes, base_addr, op_name):
    """Lower an element-wise unary tile op (or a plain copy) to micro instructions.

    Args:
        src_view: Source view loaded into the source tile buffer.
        out_view: Destination view that receives the output tile.
        dtype: Element type of both tile buffers.
        shape: Static shape (every extent an ``int``) of both tile buffers.
        lanes: Requested lane count; resolved through ``_resolve_lanes``.
        base_addr: Start address of the scratch region. The source tile is
            placed at ``base_addr`` and the output tile directly after it.
        op_name: Name of the micro op looked up on ``_dsl_pto``, or ``None``
            to store each loaded vector unchanged.

    Returns:
        ``out_view``, after the output tile has been stored into it.

    Raises:
        ValueError: If ``shape`` contains a non-``int`` extent.
    """
    if any(not isinstance(dim, int) for dim in shape):
        raise ValueError(
            "micro tile lowering currently requires a static integer shape."
        )

    lanes = _resolve_lanes(dtype, lanes)
    # The tile is processed as one flat run of elements.
    element_count = 1
    for dim in shape:
        element_count *= dim

    buf_bytes = element_count * _dtype_byte_width(dtype)
    src_addr = _const_i64(base_addr)
    out_addr = _const_i64(base_addr + buf_bytes)

    src_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=src_addr)
    out_tile = _dsl_pto.make_tile_buffer(dtype, shape, space="VEC").alloc(addr=out_addr)

    _dsl_pto.load(src_view, src_tile)

    # The vector type depends only on (lanes, dtype), so build it once instead
    # of twice per chunk inside the loop.
    vreg_type = _dsl_pto.VRegType(lanes, dtype)
    src_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), src_addr)
    out_ptr = _dsl_pto.castptr(_dsl_pto.ptr(dtype, space="VEC"), out_addr)
    micro_op = getattr(_dsl_pto, op_name) if op_name is not None else None

    for offset in range_constexpr(0, element_count, lanes):
        # The last chunk may hold fewer than `lanes` elements.
        active = builtins.min(lanes, element_count - offset)
        mask = _mask_for_chunk(dtype, active)
        index = _scalar.const(offset)
        src_vec = _dsl_pto.vlds(vreg_type, src_ptr, index)
        out_vec = src_vec if micro_op is None else micro_op(vreg_type, src_vec, mask)
        _dsl_pto.vsts(out_vec, out_ptr, index, mask)

    _dsl_pto.store(out_tile, out_view)
    return out_view
def gather(src, out, *, indices=None, mask_pattern=None):
    """Emit a ``TGatherOp`` from ``src`` into ``out``.

    Args:
        src: Source operand.
        out: Destination operand.
        indices: Optional index operand; forwarded as ``indices`` when given.
        mask_pattern: Optional attribute name on ``_pto.MaskPattern``; when
            given it is wrapped in a ``MaskPatternAttr`` and forwarded as
            ``maskPattern``.

    Returns:
        ``out``.
    """
    # Omitted options are left out of the call rather than passed as None.
    optional_operands = {} if indices is None else {"indices": indices}
    if mask_pattern is not None:
        pattern_enum = getattr(_pto.MaskPattern, mask_pattern)
        optional_operands["maskPattern"] = _pto.MaskPatternAttr.get(pattern_enum)
    _call(_pto.TGatherOp, src, out, **optional_operands)
    return out
out): + _call(_pto.TExtractOp, source, index_row, index_col, out) + return out + + +def insert(source, index_row, index_col, out): + _call(_pto.TInsertOp, source, index_row, index_col, out) + return out + + +def row_sum(src, tmp, dst): + _call(_pto.TRowSumOp, src=src, tmp=tmp, dst=dst) + return dst + + +def row_min(src, tmp, dst): + _call(_pto.TRowMinOp, src=src, tmp=tmp, dst=dst) + return dst + + +def row_max(src, tmp, dst): + _call(_pto.TRowMaxOp, src=src, tmp=tmp, dst=dst) + return dst + + +def col_sum(src, tmp, dst, *, is_binary=True): + _call(_pto.TColSumOp, src=src, tmp=tmp, dst=dst, isBinary=is_binary) + return dst + + +def col_min(src, dst): + _call(_pto.TColMinOp, src=src, dst=dst) + return dst + + +def col_max(src, dst): + _call(_pto.TColMaxOp, src=src, dst=dst) + return dst + + +def row_expand(src, dst): + _call(_pto.TRowExpandOp, src=src, dst=dst) + return dst + + +def row_expand_sub(src0, src1, dst): + _call(_pto.TRowExpandSubOp, src0=src0, src1=src1, dst=dst) + return dst + + +def row_expand_mul(src0, src1, dst): + _call(_pto.TRowExpandMulOp, src0=src0, src1=src1, dst=dst) + return dst + + +def row_expand_div(src0, src1, dst): + _call(_pto.TRowExpandDivOp, src0=src0, src1=src1, dst=dst) + return dst + + +def col_expand(src, dst): + _call(_pto.TColExpandOp, src=src, dst=dst) + return dst + + +def col_expand_mul(src0, src1, dst): + _call(_pto.TColExpandMulOp, src0=src0, src1=src1, dst=dst) + return dst + + +def col_expand_max(src0, src1, dst): + _call(_pto.TColExpandMaxOp, src0=src0, src1=src1, dst=dst) + return dst + + +def col_expand_min(src0, src1, dst): + _call(_pto.TColExpandMinOp, src0=src0, src1=src1, dst=dst) + return dst + + +def trans(src, dst): + _call(_pto.TTransOp, src, dst) + return dst + + +def mrgsort(src, dst, block_len): + _call(_pto.TMrgSortOp, srcs=[src], dsts=[dst], blockLen=block_len) + return dst + + +def sort32(src, dst, idx): + _call(_pto.TSort32Op, src, dst, idx) + return dst + + +def matmul(lhs, rhs, out): + 
_call(_pto.TMatmulOp, None, lhs, rhs, out) + return out + + +def matmul_acc(acc, lhs, rhs, out): + _call(_pto.TMatmulAccOp, None, acc, lhs, rhs, out) + return out + + +def matmul_bias(lhs, rhs, bias, out): + _call(_pto.TMatmulBiasOp, None, lhs, rhs, bias, out) + return out + + +def matmul_mx(lhs, lhs_scale, rhs, rhs_scale, out): + _call(_pto.TMatmulMxOp, None, lhs, lhs_scale, rhs, rhs_scale, out) + return out + + +def matmul_mx_acc(acc, lhs, lhs_scale, rhs, rhs_scale, out): + _call(_pto.TMatmulMxAccOp, None, acc, lhs, lhs_scale, rhs, rhs_scale, out) + return out + + +def matmul_mx_bias(lhs, lhs_scale, rhs, rhs_scale, bias, out): + _call(_pto.TMatmulMxBiasOp, None, lhs, lhs_scale, rhs, rhs_scale, bias, out) + return out + + +def full_mask_b32(): + return _dsl_pto.pset_b32(_dsl_pto.MaskType(), "PAT_ALL") + + +def vload(ptr, offset, *, lanes=64, dtype=None): + dtype = _dsl_pto.float32 if dtype is None else dtype + return _dsl_pto.vlds(_dsl_pto.VRegType(lanes, dtype), ptr, offset) + + +def vstore(vector, ptr, offset, *, mask=None): + if mask is None: + mask = full_mask_b32() + _dsl_pto.vsts(vector, ptr, offset, mask) + return ptr + + +def vector_copy(src_ptr, dst_ptr, offset, *, lanes=64, dtype=None): + vec = vload(src_ptr, offset, lanes=lanes, dtype=dtype) + vstore(vec, dst_ptr, offset) + return vec + + +TLoad = load_tile +TStore = store_tile +TMov = move_tile +TAdd = add +TAddS = adds +TSub = sub +TSubS = subs +TMul = mul +TMulS = muls +TDiv = div +TDivS = divs +TMax = max +TMaxS = maxs +TMin = min +TMinS = mins +TAnd = and_ +TOr = or_ +TXor = xor +TShl = shl +TShlS = shls +TShr = shr +TShrS = shrs +TCmp = compare +TExp = exp +TLog = log +TRelu = relu +TAbs = abs +TSqrt = sqrt +TRsqrt = rsqrt +TRecip = reciprocal +TLRelu = lrelu +TGather = gather +TScatter = scatter +TSel = select +TConcat = concat +TExtract = extract +TInsert = insert +TRowSum = row_sum +TRowMin = row_min +TRowMax = row_max +TColSum = col_sum +TColMin = col_min +TColMax = col_max +TRowExpand = 
row_expand +TRowExpandSub = row_expand_sub +TRowExpandMul = row_expand_mul +TRowExpandDiv = row_expand_div +TColExpand = col_expand +TColExpandMul = col_expand_mul +TColExpandMax = col_expand_max +TColExpandMin = col_expand_min +TTrans = trans +TMrgSort = mrgsort +TSort32 = sort32 +TMatmul = matmul +TMatmulAcc = matmul_acc +TMatmulBias = matmul_bias +TMatmulMx = matmul_mx +TMatmulMxAcc = matmul_mx_acc +TMatmulMxBias = matmul_mx_bias + + +__all__ = [ + "VF_IMPL_DEFAULT", + "VF_IMPL_1D_NO_POST_UPDATE", + "VF_IMPL_1D_POST_UPDATE", + "VF_IMPL_2D_NO_POST_UPDATE", + "VF_IMPL_2D_POST_UPDATE", + "TAbs", + "TAdd", + "TAddS", + "TAnd", + "TColExpand", + "TColExpandMax", + "TColExpandMin", + "TColExpandMul", + "TColMax", + "TColMin", + "TColSum", + "TConcat", + "TCmp", + "TDiv", + "TDivS", + "TExp", + "TExtract", + "TGather", + "TInsert", + "TLRelu", + "TLoad", + "TLog", + "TMatmul", + "TMatmulAcc", + "TMatmulBias", + "TMatmulMx", + "TMatmulMxAcc", + "TMatmulMxBias", + "TMax", + "TMaxS", + "TMin", + "TMinS", + "TMov", + "TMrgSort", + "TMul", + "TMulS", + "TOr", + "TRecip", + "TRelu", + "TRowExpand", + "TRowExpandDiv", + "TRowExpandMul", + "TRowExpandSub", + "TRowMax", + "TRowMin", + "TRowSum", + "TRsqrt", + "TScatter", + "TSel", + "TShl", + "TShlS", + "TShr", + "TShrS", + "TSort32", + "TSqrt", + "TStore", + "TSub", + "TSubS", + "TTrans", + "TXor", + "add", + "add_micro", + "abs_micro", + "adds", + "and_", + "col_expand", + "col_expand_micro", + "col_expand_max", + "col_expand_min", + "col_expand_mul", + "col_max", + "col_max_micro", + "col_min", + "col_min_micro", + "col_sum", + "col_sum_micro", + "compare", + "concat", + "div", + "divs", + "exp", + "exp_micro", + "extract", + "full_mask_b32", + "gather", + "gather_micro", + "insert", + "load_tile", + "log", + "log_micro", + "lrelu", + "matmul", + "matmul_acc", + "matmul_bias", + "matmul_mx", + "matmul_mx_acc", + "matmul_mx_bias", + "max", + "maxs", + "min", + "mins", + "move_tile", + "mov_micro", + "mrgsort", + 
"mrgsort_micro", + "mul", + "muls", + "or_", + "reciprocal", + "reciprocal_micro", + "relu", + "relu_micro", + "row_expand", + "row_expand_div_micro", + "row_expand_micro", + "row_expand_mul_micro", + "row_expand_div", + "row_expand_sub_micro", + "row_expand_mul", + "row_expand_sub", + "row_max", + "row_max_micro", + "row_min", + "row_min_micro", + "row_sum", + "row_sum_micro", + "rsqrt", + "rsqrt_micro", + "scatter", + "select", + "shl", + "shls", + "shr", + "shrs", + "sort32", + "sort32_micro", + "sqrt", + "sqrt_micro", + "store_tile", + "sub", + "sub_micro", + "subs", + "div_micro", + "mul_micro", + "or_micro", + "trans", + "vector_copy", + "vload", + "vstore", + "xor", +] diff --git a/ptodsl/lib/a5/tile_micro_coverage.py b/ptodsl/lib/a5/tile_micro_coverage.py new file mode 100644 index 00000000..c3b98677 --- /dev/null +++ b/ptodsl/lib/a5/tile_micro_coverage.py @@ -0,0 +1,199 @@ +from ptodsl import tile + +TILE_MICRO_COVERAGE = { + "mov": { + "status": "implemented", + "helper": "mov_micro", + "note": "UB stage + vlds/vsts copy loop.", + }, + "add": { + "status": "implemented", + "helper": "add_micro", + "note": "UB stage + constexpr-specialized TBinOp-style vlds/vadd/vsts lowering.", + }, + "sub": { + "status": "implemented", + "helper": "sub_micro", + "note": "UB stage + constexpr-specialized TBinOp-style vlds/vsub/vsts lowering.", + }, + "div": { + "status": "implemented", + "helper": "div_micro", + "note": "UB stage + constexpr-specialized TBinOp-style vlds/vdiv/vsts lowering.", + }, + "mul": { + "status": "implemented", + "helper": "mul_micro", + "note": "UB stage + constexpr-specialized TBinOp-style vlds/vmul/vsts lowering.", + }, + "or_": { + "status": "implemented", + "helper": "or_micro", + "note": "UB stage + constexpr-specialized TBinOp-style vlds/vor/vsts lowering.", + }, + "gather": { + "status": "partial", + "helper": "gather_micro", + "note": "Indexed gather is implemented via vgather2 for same-width source/index pairs; mask-pattern gather still 
needs unsupported vsqz-style micro support.", + }, + "exp": { + "status": "implemented", + "helper": "exp_micro", + "note": "UB stage + vlds/vexp/vsts loop.", + }, + "log": { + "status": "implemented", + "helper": "log_micro", + "note": "UB stage + vlds/vln/vsts loop.", + }, + "relu": { + "status": "implemented", + "helper": "relu_micro", + "note": "UB stage + vlds/vrelu/vsts loop.", + }, + "abs": { + "status": "implemented", + "helper": "abs_micro", + "note": "UB stage + vlds/vabs/vsts loop.", + }, + "sqrt": { + "status": "implemented", + "helper": "sqrt_micro", + "note": "UB stage + vlds/vsqrt/vsts loop.", + }, + "rsqrt": { + "status": "implemented", + "helper": "rsqrt_micro", + "note": "UB stage + vsqrt/vrec micro sequence.", + }, + "reciprocal": { + "status": "implemented", + "helper": "reciprocal_micro", + "note": "UB stage + vlds/vrec/vsts loop.", + }, + "matmul": { + "status": "blocked", + "helper": None, + "note": "Cube/L0 path is not a pure vector-micro rewrite target.", + }, + "matmul_bias": { + "status": "blocked", + "helper": None, + "note": "Cube/L0 path is not a pure vector-micro rewrite target.", + }, + "matmul_acc": { + "status": "blocked", + "helper": None, + "note": "Cube/L0 path is not a pure vector-micro rewrite target.", + }, + "extract": { + "status": "blocked", + "helper": None, + "note": "Layout/L0 extraction op, not a vector-micro compute rewrite.", + }, + "row_sum": { + "status": "implemented", + "helper": "row_sum_micro", + "note": "Static-shape row reduction via vcadd + point-store.", + }, + "row_min": { + "status": "implemented", + "helper": "row_min_micro", + "note": "Static-shape row reduction via vcmin + point-store.", + }, + "row_max": { + "status": "implemented", + "helper": "row_max_micro", + "note": "Static-shape row reduction via vcmax + point-store.", + }, + "row_expand": { + "status": "implemented", + "helper": "row_expand_micro", + "note": "Static-shape canonical broadcast via vldas/vldus/vdup/vsts.", + }, + "row_expand_sub": 
{ + "status": "implemented", + "helper": "row_expand_sub_micro", + "note": "Static-shape canonical broadcast via vldas/vldus/vdup/vsub/vsts.", + }, + "row_expand_div": { + "status": "implemented", + "helper": "row_expand_div_micro", + "note": "Static-shape canonical broadcast via vldas/vldus/vdup/vdiv/vsts.", + }, + "row_expand_mul": { + "status": "implemented", + "helper": "row_expand_mul_micro", + "note": "Static-shape canonical broadcast via vldas/vldus/vdup/vmul/vsts.", + }, + "col_sum": { + "status": "implemented", + "helper": "col_sum_micro", + "note": "Static-shape TColReduceOps-style column reduction via vadd.", + }, + "col_min": { + "status": "implemented", + "helper": "col_min_micro", + "note": "Static-shape TColReduceOps-style column reduction via vmin.", + }, + "col_max": { + "status": "implemented", + "helper": "col_max_micro", + "note": "Static-shape TColReduceOps-style column reduction via vmax.", + }, + "col_expand": { + "status": "implemented", + "helper": "col_expand_micro", + "note": "Static-shape canonical broadcast via vlds/vsts replication.", + }, + "mrgsort": { + "status": "implemented", + "helper": "mrgsort_micro", + "note": "Single-list row-major merge sort via vmrgsort4.", + }, + "sort32": { + "status": "implemented", + "helper": "sort32_micro", + "note": "Static-shape block sort via vbitsort.", + }, + "subset": { + "status": "not_applicable", + "helper": None, + "note": "View helper only, not a tile compute op.", + }, +} + + +def coverage_summary(): + counts = {} + for entry in TILE_MICRO_COVERAGE.values(): + status = entry["status"] + counts[status] = counts.get(status, 0) + 1 + return counts + + +def coverage_markdown(): + counts = coverage_summary() + lines = [ + "# Tile Micro Coverage", + "", + f"- Total public tile ops: `{len(tile.__all__)}`", + f"- Implemented: `{counts.get('implemented', 0)}`", + f"- Partial: `{counts.get('partial', 0)}`", + f"- Pending: `{counts.get('pending', 0)}`", + f"- Blocked: `{counts.get('blocked', 0)}`", + 
f"- Not applicable: `{counts.get('not_applicable', 0)}`", + "", + "| tile op | status | helper | note |", + "| --- | --- | --- | --- |", + ] + for name in tile.__all__: + entry = TILE_MICRO_COVERAGE[name] + helper = entry["helper"] or "-" + lines.append( + f"| `{name}` | `{entry['status']}` | `{helper}` | {entry['note']} |" + ) + return "\n".join(lines) + "\n" + + +__all__ = ["TILE_MICRO_COVERAGE", "coverage_markdown", "coverage_summary"] diff --git a/ptodsl/pto.py b/ptodsl/pto.py new file mode 100644 index 00000000..b69bfbb0 --- /dev/null +++ b/ptodsl/pto.py @@ -0,0 +1,6 @@ +from .api import pto as _pto +from .api.pto import __all__ + + +def __getattr__(name): + return getattr(_pto, name) diff --git a/ptodsl/scalar.py b/ptodsl/scalar.py new file mode 100644 index 00000000..07d48be8 --- /dev/null +++ b/ptodsl/scalar.py @@ -0,0 +1,6 @@ +from .api import scalar as _scalar +from .api.scalar import __all__ + + +def __getattr__(name): + return getattr(_scalar, name) diff --git a/ptodsl/test_util.py b/ptodsl/test_util.py index bf5badb3..95e1165d 100644 --- a/ptodsl/test_util.py +++ b/ptodsl/test_util.py @@ -1,19 +1,13 @@ -import os - - -DEVICE_ENV_VAR = "PTODSL_TEST_DEVICE_ID" -DEFAULT_DEVICE_ID = "0" -DEVICE_PREFIX = "npu:" - - -def get_test_device() -> str: - device_id = os.getenv(DEVICE_ENV_VAR) - if not device_id: - print( - f"Warning: {DEVICE_ENV_VAR} is not set; defaulting to {DEFAULT_DEVICE_ID}." 
- ) - device_id = DEFAULT_DEVICE_ID - - if device_id.startswith(DEVICE_PREFIX): - return device_id - return f"{DEVICE_PREFIX}{device_id}" +from .utils.test_util import ( + DEFAULT_DEVICE_ID, + DEVICE_ENV_VAR, + DEVICE_PREFIX, + get_test_device, +) + +__all__ = [ + "DEVICE_ENV_VAR", + "DEFAULT_DEVICE_ID", + "DEVICE_PREFIX", + "get_test_device", +] diff --git a/ptodsl/tile.py b/ptodsl/tile.py new file mode 100644 index 00000000..3d8658c2 --- /dev/null +++ b/ptodsl/tile.py @@ -0,0 +1,6 @@ +from .api import tile as _tile +from .api.tile import __all__ + + +def __getattr__(name): + return getattr(_tile, name) diff --git a/ptodsl/utils/__init__.py b/ptodsl/utils/__init__.py new file mode 100644 index 00000000..8fbbcda0 --- /dev/null +++ b/ptodsl/utils/__init__.py @@ -0,0 +1,4 @@ +from .bench import do_bench +from .test_util import get_test_device + +__all__ = ["do_bench", "get_test_device"] diff --git a/ptodsl/utils/bench.py b/ptodsl/utils/bench.py new file mode 100644 index 00000000..9bc0db7e --- /dev/null +++ b/ptodsl/utils/bench.py @@ -0,0 +1,56 @@ +from typing import Callable, List, Literal, Union + + +def do_bench( + fn: Callable, + warmup_iters: int = 5, + benchmark_iters: int = 15, + aggregation: Literal["mean", "none"] = "mean", + unit: Literal["s", "ms", "us", "ns"] = "us", + flush_cache: bool = True, +) -> Union[float, List[float]]: + """ + Benchmark a given function with warmup. + + Args: + fn: Function to benchmark. + warmup_iters: Number of warmup runs. + benchmark_iters: Number of benchmark runs. + aggregation: Aggregation mode for benchmark times. + unit: Time unit of the benchmarks. + flush_cache: if we should overwrite l2 cache between every iteration + Returns: + Runtime, or list of runtimes, in specified units. 
+ """ + import torch + import torch_npu + + start_events = [torch.npu.Event(enable_timing=True) for _ in range(benchmark_iters)] + end_events = [torch.npu.Event(enable_timing=True) for _ in range(benchmark_iters)] + + # Allocate a 256 MB tensor which we write to every iteration to flush L2 cache + # https://github.com/tile-ai/tilelang/blob/main/tilelang/profiler/bench.py#L103 + cache_size = 256 * 1024 * 1024 + cache = torch.empty((cache_size), dtype=torch.int8).npu() + + for _ in range(warmup_iters): + fn() + torch_npu.npu.synchronize() + + # It's not easy to time a kernel in a way that satisfies the following two at the same time: + # 1) Ignores cache flushing, and 2) Ignoring kernel launch overhead. Here we ignore cache flushing. + for i in range(benchmark_iters): + if flush_cache: + cache.zero_() + start_events[i].record() + fn() + end_events[i].record() + + torch_npu.npu.synchronize() + factor = {"s": 1e-3, "ms": 1e0, "us": 1e3, "ns": 1e6}[unit] + times = [ + factor * start.elapsed_time(end) for start, end in zip(start_events, end_events) + ] + if aggregation == "mean": + return sum(times) / len(times) + return times diff --git a/ptodsl/utils/test_util.py b/ptodsl/utils/test_util.py new file mode 100644 index 00000000..bf5badb3 --- /dev/null +++ b/ptodsl/utils/test_util.py @@ -0,0 +1,19 @@ +import os + + +DEVICE_ENV_VAR = "PTODSL_TEST_DEVICE_ID" +DEFAULT_DEVICE_ID = "0" +DEVICE_PREFIX = "npu:" + + +def get_test_device() -> str: + device_id = os.getenv(DEVICE_ENV_VAR) + if not device_id: + print( + f"Warning: {DEVICE_ENV_VAR} is not set; defaulting to {DEFAULT_DEVICE_ID}." 
+ ) + device_id = DEFAULT_DEVICE_ID + + if device_id.startswith(DEVICE_PREFIX): + return device_id + return f"{DEVICE_PREFIX}{device_id}" diff --git a/ptodsl/pyproject.toml b/pyproject.toml similarity index 61% rename from ptodsl/pyproject.toml rename to pyproject.toml index 2b57c249..de2df06b 100644 --- a/ptodsl/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,18 @@ authors = [ { name = "pto-dsl contributors" } ] +[project.optional-dependencies] +dev = ["matplotlib"] + [tool.setuptools] -packages = ["ptodsl"] +packages = [ + "ptodsl", + "ptodsl.api", + "ptodsl.compiler", + "ptodsl.lib", + "ptodsl.lib.a5", + "ptodsl.utils", +] [tool.setuptools.package-dir] -ptodsl = "." +ptodsl = "ptodsl" diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 00000000..ea508c64 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + require_npu: marks tests as requiring NPU hardware (deselect with '-m "not require_npu"') diff --git a/scripts/generate_a5_pto.py b/scripts/generate_a5_pto.py new file mode 100644 index 00000000..a7d97d6c --- /dev/null +++ b/scripts/generate_a5_pto.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 + +import argparse +import pathlib +import subprocess +import sys + + +_ROOT = pathlib.Path(__file__).resolve().parents[1] +if str(_ROOT) not in sys.path: + sys.path.insert(0, str(_ROOT)) + +from ptodsl.lib import a5 + + +_DEFAULT_OUTPUT_DIR = _ROOT / "ptodsl" / "lib" / "a5" / "generated" +_DEFAULT_PTOAS = _ROOT.parent / "PTOAS" / "build-src312" / "tools" / "ptoas" / "ptoas" + + +def emit_kernels(*, output_dir, ptoas_bin=None, emit_cpp=False): + output_dir.mkdir(parents=True, exist_ok=True) + generated = [] + for kernel_name, builder in a5.KERNEL_BUILDERS.items(): + module = builder() + pto_path = output_dir / f"{kernel_name}.pto" + pto_path.write_text(f"{module}\n", encoding="utf-8") + generated.append(pto_path) + + if emit_cpp: + if ptoas_bin is None: + raise ValueError("`emit_cpp=True` requires `ptoas_bin`.") + cpp_path = output_dir / 
f"{kernel_name}.cpp" + try: + subprocess.run( + [str(ptoas_bin), str(pto_path), "-o", str(cpp_path)], + check=True, + cwd=str(output_dir), + ) + except subprocess.CalledProcessError as exc: + print( + f"warning: failed to lower {pto_path.name} to C++ with ptoas: {exc}", + file=sys.stderr, + ) + return generated + + +def _parse_args(): + parser = argparse.ArgumentParser( + description="Generate PTODSL A5 translation artifacts as `.pto` files." + ) + parser.add_argument( + "--output-dir", + type=pathlib.Path, + default=_DEFAULT_OUTPUT_DIR, + help=f"Directory to write generated artifacts. Default: {_DEFAULT_OUTPUT_DIR}", + ) + parser.add_argument( + "--ptoas", + type=pathlib.Path, + default=_DEFAULT_PTOAS, + help=f"ptoas binary to use when `--emit-cpp` is set. Default: {_DEFAULT_PTOAS}", + ) + parser.add_argument( + "--emit-cpp", + action="store_true", + help="Also run ptoas and write `.cpp` files next to the generated `.pto` files.", + ) + return parser.parse_args() + + +def main(): + args = _parse_args() + ptoas_bin = args.ptoas if args.emit_cpp else None + generated = emit_kernels( + output_dir=args.output_dir, + ptoas_bin=ptoas_bin, + emit_cpp=args.emit_cpp, + ) + for path in generated: + print(path) + + +if __name__ == "__main__": + main() diff --git a/scripts/update_tile_micro_checklist.py b/scripts/update_tile_micro_checklist.py new file mode 100644 index 00000000..8eaaee9f --- /dev/null +++ b/scripts/update_tile_micro_checklist.py @@ -0,0 +1,21 @@ +#!/usr/bin/env python3 + +from pathlib import Path +import sys + + +_REPO_ROOT = Path(__file__).resolve().parents[1] +if str(_REPO_ROOT) not in sys.path: + sys.path.insert(0, str(_REPO_ROOT)) + +from ptodsl.lib.a5.tile_micro_coverage import coverage_markdown + + +def main(): + target = _REPO_ROOT / "ptodsl" / "lib" / "a5" / "TILE_MICRO_CHECKLIST.md" + target.write_text(coverage_markdown(), encoding="utf-8") + print(target) + + +if __name__ == "__main__": + main() diff --git a/tests/frontend/test_add_dynamic_ir.py 
b/tests/frontend/test_add_dynamic_ir.py index f3ded1a2..ebabd0ac 100644 --- a/tests/frontend/test_add_dynamic_ir.py +++ b/tests/frontend/test_add_dynamic_ir.py @@ -2,11 +2,10 @@ from mlir.ir import IndexType from mlir.dialects import arith, func, pto as _pto, scf -from ptodsl import to_ir_module +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s -import ptodsl.language as pto - -const = pto.const +const = s.const def meta_data(): @@ -51,12 +50,12 @@ def vec_add_1d_dynamic( vid = cidmul + sub_bid num_blocks = pto.get_block_num() - vid_idx = pto.index_cast(vid) - num_cores = pto.index_cast(num_blocks) - total_elements = pto.index_cast(argN) + vid_idx = s.index_cast(vid) + num_cores = s.index_cast(num_blocks) + total_elements = s.index_cast(argN) - num_tiles_global = pto.ceil_div(total_elements, c_tile) - num_tiles_per_core = pto.ceil_div(num_tiles_global, num_cores) + num_tiles_global = s.ceil_div(total_elements, c_tile) + num_tiles_per_core = s.ceil_div(num_tiles_global, num_cores) tile_offset_this_core = vid_idx * num_tiles_per_core with pto.vector_section(): @@ -73,29 +72,38 @@ def vec_add_1d_dynamic( need_truncate = tiles_end_this_core > num_tiles_global remaining_tiles = num_tiles_global - tile_offset_this_core - tiles_to_process = pto.select( + tiles_to_process = s.select( need_truncate, remaining_tiles, num_tiles_per_core ) elements_to_process = tiles_to_process * c_tile with pto.if_context(elements_to_process > c0): - for i in pto.for_range(c0, tiles_to_process, c1): + for i in pto.range(c0, tiles_to_process, c1): tile_offset_global = i + tile_offset_this_core offset_global = tile_offset_global * c_tile sv0 = pto.slice_view( - subtensor_type, source=tv0, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv0, + offsets=[offset_global], + sizes=[c_tile], ) sv1 = pto.slice_view( - subtensor_type, source=tv1, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv1, + offsets=[offset_global], + 
sizes=[c_tile], ) sv2 = pto.slice_view( - subtensor_type, source=tv2, offsets=[offset_global], sizes=[c_tile] + subtensor_type, + source=tv2, + offsets=[offset_global], + sizes=[c_tile], ) pto.load(sv0, tb0) pto.load(sv1, tb1) - pto.add(tb0, tb1, tb2) + tile.add(tb0, tb1, tb2) pto.store(tb2, sv2) @@ -118,7 +126,9 @@ def build_verbose(): sl = _pto.SLayoutAttr.get(_pto.SLayout.NoneBox) pd = _pto.PadValueAttr.get(_pto.PadValue.Null) cfg = _pto.TileBufConfigAttr.get(bl, sl, 512, pd) - tile_buf = _pto.TileBufType.get([1, tile_length], f32, vec, [1, tile_length], cfg) + tile_buf = _pto.TileBufType.get( + [1, tile_length], f32, vec, [1, tile_length], cfg + ) fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32, i32], []) with InsertionPoint(module.body): @@ -150,9 +160,15 @@ def build_verbose(): vec_section = _pto.SectionVectorOp() vec_block = vec_section.body.blocks.append() with InsertionPoint(vec_block): - tv0 = _pto.MakeTensorViewOp(tensor_view, arg0, [total_elements], [c1]).result - tv1 = _pto.MakeTensorViewOp(tensor_view, arg1, [total_elements], [c1]).result - tv2 = _pto.MakeTensorViewOp(tensor_view, arg2, [total_elements], [c1]).result + tv0 = _pto.MakeTensorViewOp( + tensor_view, arg0, [total_elements], [c1] + ).result + tv1 = _pto.MakeTensorViewOp( + tensor_view, arg1, [total_elements], [c1] + ).result + tv2 = _pto.MakeTensorViewOp( + tensor_view, arg2, [total_elements], [c1] + ).result tb0 = _pto.AllocTileOp(tile_buf).result tb1 = _pto.AllocTileOp(tile_buf).result @@ -184,8 +200,12 @@ def build_verbose(): work_if = scf.IfOp(has_elements) with InsertionPoint(work_if.then_block): for i in scf.for_(c0, tiles_to_process, c1): - tile_offset_global = arith.AddIOp(i, tile_offset_this_core).result - offset_global = arith.MulIOp(tile_offset_global, c_tile).result + tile_offset_global = arith.AddIOp( + i, tile_offset_this_core + ).result + offset_global = arith.MulIOp( + tile_offset_global, c_tile + ).result sv0 = _pto.PartitionViewOp( tile_view, tv0, 
offsets=[offset_global], sizes=[c_tile] diff --git a/tests/frontend/test_add_ir.py b/tests/frontend/test_add_ir.py index a6b5a936..4a321e39 100644 --- a/tests/frontend/test_add_ir.py +++ b/tests/frontend/test_add_ir.py @@ -1,9 +1,10 @@ from mlir.ir import Context, Location, Module, InsertionPoint, IntegerType from mlir.ir import F32Type, IndexType from ptodsl import to_ir_module -import ptodsl.language as pto +from ptodsl import pto, tile +from ptodsl import scalar as s -const = pto.const +const = s.const def meta_data(): @@ -14,7 +15,11 @@ def meta_data(): subtensor_type = pto.SubTensorType(shape=[32, 32], dtype=dtype) tile_cfg = pto.TileBufConfig() tile_type = pto.TileBufType( - shape=[32, 32], valid_shape=[-1, -1], dtype=dtype, memory_space="VEC", config=tile_cfg + shape=[32, 32], + valid_shape=[-1, -1], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, ) return { "ptr_type": ptr_type, @@ -43,18 +48,24 @@ def vec_add_2d_static( cidmul = cid * sub_bnum vid = cidmul + sub_bid - v_row_idx = pto.index_cast(arg_vrow_i32) - v_col_idx = pto.index_cast(arg_vcol_i32) + v_row_idx = s.index_cast(arg_vrow_i32) + v_col_idx = s.index_cast(arg_vcol_i32) tv0 = pto.as_tensor(tensor_type, ptr=arg0, shape=[c1280, c32], strides=[c32, c1]) tv1 = pto.as_tensor(tensor_type, ptr=arg1, shape=[c1280, c32], strides=[c32, c1]) tv2 = pto.as_tensor(tensor_type, ptr=arg2, shape=[c1280, c32], strides=[c32, c1]) - vid_idx = pto.index_cast(vid) + vid_idx = s.index_cast(vid) offset_row = vid_idx * c32 - sv0 = pto.slice_view(subtensor_type, source=tv0, offsets=[offset_row, c0], sizes=[c32, c32]) - sv1 = pto.slice_view(subtensor_type, source=tv1, offsets=[offset_row, c0], sizes=[c32, c32]) - sv2 = pto.slice_view(subtensor_type, source=tv2, offsets=[offset_row, c0], sizes=[c32, c32]) + sv0 = pto.slice_view( + subtensor_type, source=tv0, offsets=[offset_row, c0], sizes=[c32, c32] + ) + sv1 = pto.slice_view( + subtensor_type, source=tv1, offsets=[offset_row, c0], sizes=[c32, c32] + ) + sv2 = 
pto.slice_view( + subtensor_type, source=tv2, offsets=[offset_row, c0], sizes=[c32, c32] + ) with pto.vector_section(): tb0 = pto.alloc_tile(tile_type, valid_row=v_row_idx, valid_col=v_col_idx) @@ -63,32 +74,32 @@ def vec_add_2d_static( pto.load(sv0, tb0) pto.load(sv1, tb1) - pto.add(tb0, tb1, tb2) + tile.add(tb0, tb1, tb2) pto.store(tb2, sv2) def build(): - from mlir.dialects import func, arith, pto + from mlir.dialects import arith, func, pto as _pto with Context() as ctx, Location.unknown(): - pto.register_dialect(ctx, load=True) + _pto.register_dialect(ctx, load=True) m = Module.create() f32 = F32Type.get() i32 = IntegerType.get_signless(32) - ptr_f32 = pto.PtrType.get(f32) + ptr_f32 = _pto.PtrType.get(f32) - tv2_f32 = pto.TensorViewType.get(2, f32) - tile_view_32 = pto.PartitionTensorViewType.get([32, 32], f32) - vec = pto.AddressSpaceAttr.get(pto.AddressSpace.VEC) - bl = pto.BLayoutAttr.get(pto.BLayout.RowMajor) - sl = pto.SLayoutAttr.get(pto.SLayout.NoneBox) - pd = pto.PadValueAttr.get(pto.PadValue.Null) + tv2_f32 = _pto.TensorViewType.get(2, f32) + tile_view_32 = _pto.PartitionTensorViewType.get([32, 32], f32) + vec = _pto.AddressSpaceAttr.get(_pto.AddressSpace.VEC) + bl = _pto.BLayoutAttr.get(_pto.BLayout.RowMajor) + sl = _pto.SLayoutAttr.get(_pto.SLayout.NoneBox) + pd = _pto.PadValueAttr.get(_pto.PadValue.Null) - cfg = pto.TileBufConfigAttr.get(bl, sl, 512, pd) + cfg = _pto.TileBufConfigAttr.get(bl, sl, 512, pd) - tile_buf_dynamic = pto.TileBufType.get([32, 32], f32, vec, [-1, -1], cfg) + tile_buf_dynamic = _pto.TileBufType.get([32, 32], f32, vec, [-1, -1], cfg) fn_ty = func.FunctionType.get([ptr_f32, ptr_f32, ptr_f32, i32, i32], []) with InsertionPoint(m.body): @@ -103,42 +114,48 @@ def build(): arg0, arg1, arg2, arg_vrow_i32, arg_vcol_i32 = entry.arguments - cid = pto.GetBlockIdxOp().result - sub_bid = pto.GetSubBlockIdxOp().result - sub_bnum = pto.GetSubBlockNumOp().result + cid = _pto.GetBlockIdxOp().result + sub_bid = _pto.GetSubBlockIdxOp().result + 
sub_bnum = _pto.GetSubBlockNumOp().result cidmul = arith.MulIOp(cid, sub_bnum).result vid = arith.AddIOp(cidmul, sub_bid).result v_row_idx = arith.IndexCastOp(IndexType.get(), arg_vrow_i32).result v_col_idx = arith.IndexCastOp(IndexType.get(), arg_vcol_i32).result - tv0 = pto.MakeTensorViewOp(tv2_f32, arg0, [c1280, c32], [c32, c1]).result - tv1 = pto.MakeTensorViewOp(tv2_f32, arg1, [c1280, c32], [c32, c1]).result - tv2 = pto.MakeTensorViewOp(tv2_f32, arg2, [c1280, c32], [c32, c1]).result + tv0 = _pto.MakeTensorViewOp(tv2_f32, arg0, [c1280, c32], [c32, c1]).result + tv1 = _pto.MakeTensorViewOp(tv2_f32, arg1, [c1280, c32], [c32, c1]).result + tv2 = _pto.MakeTensorViewOp(tv2_f32, arg2, [c1280, c32], [c32, c1]).result vid_idx = arith.IndexCastOp(IndexType.get(), vid).result offset_row = arith.MulIOp(vid_idx, c32).result - sv0 = pto.PartitionViewOp( + sv0 = _pto.PartitionViewOp( tile_view_32, tv0, offsets=[offset_row, c0], sizes=[c32, c32] ).result - sv1 = pto.PartitionViewOp( + sv1 = _pto.PartitionViewOp( tile_view_32, tv1, offsets=[offset_row, c0], sizes=[c32, c32] ).result - sv2 = pto.PartitionViewOp( + sv2 = _pto.PartitionViewOp( tile_view_32, tv2, offsets=[offset_row, c0], sizes=[c32, c32] ).result - vec_section = pto.SectionVectorOp() + vec_section = _pto.SectionVectorOp() vec_block = vec_section.body.blocks.append() with InsertionPoint(vec_block): - tb0 = pto.AllocTileOp(tile_buf_dynamic, valid_row=v_row_idx, valid_col=v_col_idx).result - tb1 = pto.AllocTileOp(tile_buf_dynamic, valid_row=v_row_idx, valid_col=v_col_idx).result - tb2 = pto.AllocTileOp(tile_buf_dynamic, valid_row=v_row_idx, valid_col=v_col_idx).result - - pto.TLoadOp(None, sv0, tb0) - pto.TLoadOp(None, sv1, tb1) - pto.TAddOp(tb0, tb1, tb2) - pto.TStoreOp(None, tb2, sv2) + tb0 = _pto.AllocTileOp( + tile_buf_dynamic, valid_row=v_row_idx, valid_col=v_col_idx + ).result + tb1 = _pto.AllocTileOp( + tile_buf_dynamic, valid_row=v_row_idx, valid_col=v_col_idx + ).result + tb2 = _pto.AllocTileOp( + 
tile_buf_dynamic, valid_row=v_row_idx, valid_col=v_col_idx + ).result + + _pto.TLoadOp(None, sv0, tb0) + _pto.TLoadOp(None, sv1, tb1) + _pto.TAddOp(tb0, tb1, tb2) + _pto.TStoreOp(None, tb2, sv2) func.ReturnOp([]) diff --git a/tests/frontend/test_caller_gen.py b/tests/frontend/test_caller_gen.py index c03938fa..5a2e50ea 100644 --- a/tests/frontend/test_caller_gen.py +++ b/tests/frontend/test_caller_gen.py @@ -59,7 +59,37 @@ def mixed_kernel(data: "ptr_i8", count: "i64_type", idx: "index_dtype") -> None: 'extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *data, ' "int64_t count, int64_t idx)" ) in caller_cpp - assert "mixed_kernel<<>>((int8_t *)data, count, idx);" in caller_cpp + assert ( + "mixed_kernel<<>>((int8_t *)data, count, idx);" + in caller_cpp + ) + + +def test_generate_caller_cpp_maps_mxfp8_pointer_and_scalar_types(): + def mixed_mxfp8_kernel( + lhs: "ptr_e5m2", + lhs_scale: "ptr_e8m0", + alpha: "e4m3_type", + ) -> None: + return None + + wrapper = JitWrapper(mixed_mxfp8_kernel, meta_data=lambda: {}, block_dim=4) + wrapper._arg_types = [ + _FakeType("!pto.ptr"), + _FakeType("!pto.ptr"), + _FakeType("f8E4M3FN"), + ] + + caller_cpp = wrapper._generate_caller_cpp("generated.cpp") + + assert ( + 'extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *lhs, ' + "uint8_t *lhs_scale, uint8_t alpha)" + ) in caller_cpp + assert ( + "mixed_mxfp8_kernel<<>>((float8_e5m2_t *)lhs, " + "(float8_e8m0_t *)lhs_scale, alpha);" + ) in caller_cpp def test_generate_caller_cpp_for_dynamic_1d_add_signature(): diff --git a/tests/frontend/test_matmul_dynamic_ir.py b/tests/frontend/test_matmul_dynamic_ir.py index da7f6b5f..4dacbfed 100644 --- a/tests/frontend/test_matmul_dynamic_ir.py +++ b/tests/frontend/test_matmul_dynamic_ir.py @@ -1,9 +1,18 @@ from mlir.dialects import arith, func, pto as _pto, scf from mlir.dialects.arith import CmpIPredicate from mlir.dialects.pto import EVENT_ID0, TLOAD, TMATMUL, TMOV_M2L, TSTORE_ACC -from mlir.ir import 
Context, F32Type, IndexType, InsertionPoint, IntegerType, Location, Module +from mlir.ir import ( + Context, + F32Type, + IndexType, + InsertionPoint, + IntegerType, + Location, + Module, +) from ptodsl import to_ir_module -import ptodsl.language as pto +from ptodsl import pto, tile +from ptodsl import scalar as s def _idx_const(v: int): @@ -34,13 +43,25 @@ def meta_data(): tile_view_out = pto.SubTensorType(shape=[M, N], dtype=dtype) tile_view_bias = pto.SubTensorType(shape=[1, N], dtype=dtype) - tile_buf_aMat = pto.TileBufType(shape=[M, BASEK], dtype=dtype, memory_space="MAT") - tile_buf_bMat = pto.TileBufType(shape=[BASEK, N], dtype=dtype, memory_space="MAT") - tile_buf_biasData = pto.TileBufType(shape=[1, N], dtype=dtype, memory_space="MAT") - tile_buf_aTile = pto.TileBufType(shape=[M, BASEK], dtype=dtype, memory_space="LEFT") - tile_buf_bTile = pto.TileBufType(shape=[BASEK, N], dtype=dtype, memory_space="RIGHT") + tile_buf_aMat = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="MAT" + ) + tile_buf_bMat = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="MAT" + ) + tile_buf_biasData = pto.TileBufType( + shape=[1, N], dtype=dtype, memory_space="MAT" + ) + tile_buf_aTile = pto.TileBufType( + shape=[M, BASEK], dtype=dtype, memory_space="LEFT" + ) + tile_buf_bTile = pto.TileBufType( + shape=[BASEK, N], dtype=dtype, memory_space="RIGHT" + ) tile_buf_cTile = pto.TileBufType(shape=[M, N], dtype=dtype, memory_space="ACC") - tile_buf_biasTile = pto.TileBufType(shape=[1, N], dtype=dtype, memory_space="BIAS") + tile_buf_biasTile = pto.TileBufType( + shape=[1, N], dtype=dtype, memory_space="BIAS" + ) return { "ptr_type": ptr_dtype, @@ -60,7 +81,7 @@ def meta_data(): "tile_buf_biasTile": tile_buf_biasTile, } - const = pto.const + const = s.const @to_ir_module(meta_data=meta_data) def RunTMATMULSplitK( @@ -82,20 +103,28 @@ def RunTMATMULSplitK( cTileM = const(M) cTileN = const(N) - batch = pto.index_cast(batch_i32) + batch = 
s.index_cast(batch_i32) cBM = batch * cM - num_blocks = pto.index_cast(pto.get_block_num()) - batches_per_core = pto.ceil_div(batch, num_blocks) - bid = pto.index_cast(pto.get_block_idx()) + num_blocks = s.index_cast(pto.get_block_num()) + batches_per_core = s.ceil_div(batch, num_blocks) + bid = s.index_cast(pto.get_block_idx()) b_start = bid * batches_per_core b_end_unclamped = b_start + batches_per_core - b_end = pto.min_u(b_end_unclamped, batch) - - tvA = pto.as_tensor(tensor_type, ptr=a_ptr, shape=[cBM, cK], strides=[cK, c1]) - tvB = pto.as_tensor(tensor_type, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1]) - tvOut = pto.as_tensor(tensor_type, ptr=out_ptr, shape=[cBM, cN], strides=[cN, c1]) - tvBias = pto.as_tensor(tensor_type, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1]) + b_end = s.min_u(b_end_unclamped, batch) + + tvA = pto.as_tensor( + tensor_type, ptr=a_ptr, shape=[cBM, cK], strides=[cK, c1] + ) + tvB = pto.as_tensor( + tensor_type, ptr=b_ptr, shape=[cK, cN], strides=[cN, c1] + ) + tvOut = pto.as_tensor( + tensor_type, ptr=out_ptr, shape=[cBM, cN], strides=[cN, c1] + ) + tvBias = pto.as_tensor( + tensor_type, ptr=bias_ptr, shape=[c1, cN], strides=[cN, c1] + ) aMatTile = pto.alloc_tile(tile_buf_aMat) bMatTile = pto.alloc_tile(tile_buf_bMat) @@ -105,9 +134,9 @@ def RunTMATMULSplitK( cTile = pto.alloc_tile(tile_buf_cTile) biasTile = pto.alloc_tile(tile_buf_biasTile) - for b_idx in pto.for_range(b_start, b_end, c1): + for b_idx in pto.range(b_start, b_end, c1): row_off = b_idx * cM - for i in pto.for_range(c0, cIter, c1): + for i in pto.range(c0, cIter, c1): kOff = i * cBASEK svA = pto.slice_view( tile_view_a, @@ -135,25 +164,25 @@ def RunTMATMULSplitK( pto.record_wait_pair("LOAD", "MOV_M2L", event_id=0) - pto.mov(aMatTile, aTile) - pto.mov(bMatTile, bTile) + tile.mov(aMatTile, aTile) + tile.mov(bMatTile, bTile) with pto.if_context(isBias): - pto.mov(biasDataTile, biasTile) + tile.mov(biasDataTile, biasTile) pto.record_wait_pair("MOV_M2L", "MATMUL", 
event_id=0) - is_i0 = pto.eq(i, c0) + is_i0 = s.eq(i, c0) def _first_iter(): pto.cond( isBias, - lambda: pto.matmul_bias(aTile, bTile, biasTile, cTile), - lambda: pto.matmul(aTile, bTile, cTile), + lambda: tile.matmul_bias(aTile, bTile, biasTile, cTile), + lambda: tile.matmul(aTile, bTile, cTile), ) pto.cond( is_i0, _first_iter, - lambda: pto.matmul_acc(cTile, aTile, bTile, cTile), + lambda: tile.matmul_acc(cTile, aTile, bTile, cTile), ) pto.record_wait_pair("MATMUL", "LOAD", event_id=0) @@ -244,13 +273,21 @@ def build_verbose( tile_buf_aMat = pto.TileBufType.get([M, BASEK], dtype, mat, [M, BASEK], cfg_mat) tile_buf_bMat = pto.TileBufType.get([BASEK, N], dtype, mat, [BASEK, N], cfg_mat) - tile_buf_biasData = pto.TileBufType.get([1, N], dtype, mat, [1, N], cfg_mat_bias) - tile_buf_aTile = pto.TileBufType.get([M, BASEK], dtype, left, [M, BASEK], cfg_left) - tile_buf_bTile = pto.TileBufType.get([BASEK, N], dtype, right, [BASEK, N], cfg_right) + tile_buf_biasData = pto.TileBufType.get( + [1, N], dtype, mat, [1, N], cfg_mat_bias + ) + tile_buf_aTile = pto.TileBufType.get( + [M, BASEK], dtype, left, [M, BASEK], cfg_left + ) + tile_buf_bTile = pto.TileBufType.get( + [BASEK, N], dtype, right, [BASEK, N], cfg_right + ) tile_buf_cTile = pto.TileBufType.get([M, N], dtype, acc, [M, N], cfg_acc) tile_buf_biasTile = pto.TileBufType.get([1, N], dtype, bias, [1, N], cfg_bias) - fn_ty = func.FunctionType.get([ptr_dtype, ptr_dtype, ptr_dtype, ptr_dtype, i1, i32], []) + fn_ty = func.FunctionType.get( + [ptr_dtype, ptr_dtype, ptr_dtype, ptr_dtype, i1, i32], [] + ) with InsertionPoint(module.body): fn = func.FuncOp("RunTMATMULSplitK", fn_ty) entry = fn.add_entry_block() @@ -275,17 +312,29 @@ def build_verbose( batch = arith.IndexCastOp(IndexType.get(), batch_i32).result cBM = arith.MulIOp(batch, cM).result - num_blocks = arith.IndexCastOp(IndexType.get(), pto.GetBlockNumOp().result).result + num_blocks = arith.IndexCastOp( + IndexType.get(), pto.GetBlockNumOp().result + ).result 
batches_per_core = arith.CeilDivSIOp(batch, num_blocks).result - bid = arith.IndexCastOp(IndexType.get(), pto.GetBlockIdxOp().result).result + bid = arith.IndexCastOp( + IndexType.get(), pto.GetBlockIdxOp().result + ).result b_start = arith.MulIOp(bid, batches_per_core).result b_end_unclamped = arith.AddIOp(b_start, batches_per_core).result b_end = arith.MinUIOp(b_end_unclamped, batch).result - tvA = pto.MakeTensorViewOp(tensor_type, a_ptr, [cBM, cK], [cK, c1]).result - tvB = pto.MakeTensorViewOp(tensor_type, b_ptr, [cK, cN], [cN, c1]).result - tvOut = pto.MakeTensorViewOp(tensor_type, out_ptr, [cBM, cN], [cN, c1]).result - tvBias = pto.MakeTensorViewOp(tensor_type, bias_ptr, [c1, cN], [cN, c1]).result + tvA = pto.MakeTensorViewOp( + tensor_type, a_ptr, [cBM, cK], [cK, c1] + ).result + tvB = pto.MakeTensorViewOp( + tensor_type, b_ptr, [cK, cN], [cN, c1] + ).result + tvOut = pto.MakeTensorViewOp( + tensor_type, out_ptr, [cBM, cN], [cN, c1] + ).result + tvBias = pto.MakeTensorViewOp( + tensor_type, bias_ptr, [c1, cN], [cN, c1] + ).result aMatTile = pto.AllocTileOp(tile_buf_aMat).result bMatTile = pto.AllocTileOp(tile_buf_bMat).result @@ -303,7 +352,10 @@ def build_verbose( for i in scf.for_(c0, cIter, c1): kOff = arith.MulIOp(i, cBASEK).result svA = pto.PartitionViewOp( - tile_view_a, tvA, offsets=[row_off, kOff], sizes=[cTileM, cBASEK] + tile_view_a, + tvA, + offsets=[row_off, kOff], + sizes=[cTileM, cBASEK], ).result svB = pto.PartitionViewOp( tile_view_b, tvB, offsets=[kOff, c0], sizes=[cBASEK, cTileN] @@ -356,7 +408,10 @@ def build_verbose( pto.record_event(TMATMUL, TSTORE_ACC, EVENT_ID0) pto.wait_event(TMATMUL, TSTORE_ACC, EVENT_ID0) svOut = pto.PartitionViewOp( - tile_view_out, tvOut, offsets=[row_off, c0], sizes=[cTileM, cTileN] + tile_view_out, + tvOut, + offsets=[row_off, c0], + sizes=[cTileM, cTileN], ).result pto.TStoreOp(None, cTile, svOut) pto.record_event(TSTORE_ACC, TMATMUL, EVENT_ID0) diff --git a/tests/frontend/test_mxfp8_frontend.py 
b/tests/frontend/test_mxfp8_frontend.py new file mode 100644 index 00000000..76aebf8c --- /dev/null +++ b/tests/frontend/test_mxfp8_frontend.py @@ -0,0 +1,58 @@ +import types + +import ptodsl.language as pto + + +class _StubType: + @staticmethod + def get(): + return object() + + +def test_mxfp8_family_uses_e5m2_data_and_e8m0_scale(monkeypatch): + stub_ir = types.SimpleNamespace( + F32Type=_StubType, + Float8E5M2Type=_StubType, + Float8E8M0FNUType=_StubType, + Float8E4M3FNType=_StubType, + ) + monkeypatch.setattr(pto, "mlir_ir", stub_ir) + + mx = pto.mxfp8 + + assert mx.lhs is not None + assert mx.rhs is not None + assert mx.data is not None + assert mx.scale is not None + assert mx.acc is not None + assert mx.scale_k(64) == 2 + + +def test_float8_aliases_accept_common_mlir_ctor_names(monkeypatch): + stub_ir = types.SimpleNamespace( + F32Type=_StubType, + Float8E4M3FNType=_StubType, + Float8E5M2Type=_StubType, + Float8E8M0FNUType=_StubType, + ) + monkeypatch.setattr(pto, "mlir_ir", stub_ir) + + assert pto.fp8_e4m3 is not None + assert pto.fp8_e5m2 is not None + assert pto.fp8_e8m0 is not None + + +def test_make_mxfp8_accepts_mixed_lhs_rhs_variants(monkeypatch): + stub_ir = types.SimpleNamespace( + F32Type=_StubType, + Float8E4M3FNType=_StubType, + Float8E5M2Type=_StubType, + Float8E8M0FNUType=_StubType, + ) + monkeypatch.setattr(pto, "mlir_ir", stub_ir) + + mx = pto.make_mxfp8(lhs="e4m3", rhs="e5m2") + + assert mx.lhs is not None + assert mx.rhs is not None + assert mx.scale is not None diff --git a/tests/npu/elementwise_dynamic_multicore/.gitignore b/tests/npu/elementwise_binary_dynamic_multicore/.gitignore similarity index 100% rename from tests/npu/elementwise_dynamic_multicore/.gitignore rename to tests/npu/elementwise_binary_dynamic_multicore/.gitignore diff --git a/tests/npu/elementwise_dynamic_multicore/README.md b/tests/npu/elementwise_binary_dynamic_multicore/README.md similarity index 93% rename from tests/npu/elementwise_dynamic_multicore/README.md 
rename to tests/npu/elementwise_binary_dynamic_multicore/README.md index afd9c8bc..bd9f6145 100644 --- a/tests/npu/elementwise_dynamic_multicore/README.md +++ b/tests/npu/elementwise_binary_dynamic_multicore/README.md @@ -41,6 +41,8 @@ pytest test_builder.py -k "test_binary_2d_precision and add-float32" | sub | ✓ | ✓ | ✓ | | mul | ✓ | ✓ | ✓ | | div | ✓ | ✓ | skip | +| min | ✓ | ✓ | ✓ | +| max | ✓ | ✓ | ✓ | ## Compile a kernel manually @@ -57,4 +59,4 @@ Output: `__lib.so` in the same directory. - `ptoas` and `bisheng` on `PATH` - `/sources/pto-isa` present -- `torch_npu` installed \ No newline at end of file +- `torch_npu` installed diff --git a/tests/npu/elementwise_dynamic_multicore/builder.py b/tests/npu/elementwise_binary_dynamic_multicore/binary_builder.py similarity index 87% rename from tests/npu/elementwise_dynamic_multicore/builder.py rename to tests/npu/elementwise_binary_dynamic_multicore/binary_builder.py index c062bb3c..25c4b6a1 100644 --- a/tests/npu/elementwise_dynamic_multicore/builder.py +++ b/tests/npu/elementwise_binary_dynamic_multicore/binary_builder.py @@ -1,7 +1,7 @@ -from ptodsl import to_ir_module -import ptodsl.language as pto +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s -const = pto.const +const = s.const DTYPES = { @@ -63,12 +63,12 @@ def _1d( vid = cidmul + sub_bid num_blocks = pto.get_block_num() - vid_idx = pto.index_cast(vid) - num_cores = pto.index_cast(num_blocks) - total_elements = pto.index_cast(argN) + vid_idx = s.index_cast(vid) + num_cores = s.index_cast(num_blocks) + total_elements = s.index_cast(argN) - num_tiles_global = pto.ceil_div(total_elements, c_tile) - num_tiles_per_core = pto.ceil_div(num_tiles_global, num_cores) + num_tiles_global = s.ceil_div(total_elements, c_tile) + num_tiles_per_core = s.ceil_div(num_tiles_global, num_cores) tile_offset_this_core = vid_idx * num_tiles_per_core with pto.vector_section(): @@ -90,13 +90,13 @@ def _1d( tiles_end_this_core = tile_offset_this_core + 
num_tiles_per_core need_truncate = tiles_end_this_core > num_tiles_global remaining_tiles = num_tiles_global - tile_offset_this_core - tiles_to_process = pto.select( + tiles_to_process = s.select( need_truncate, remaining_tiles, num_tiles_per_core ) elements_to_process = tiles_to_process * c_tile with pto.if_context(elements_to_process > c0): - for i in pto.for_range(c0, tiles_to_process, c1): + for i in pto.range(c0, tiles_to_process, c1): tile_offset_global = i + tile_offset_this_core offset_global = tile_offset_global * c_tile @@ -145,15 +145,15 @@ def _2d( vid = cidmul + sub_bid num_blocks = pto.get_block_num() - vid_idx = pto.index_cast(vid) - num_cores = pto.index_cast(num_blocks) - rows = pto.index_cast(argM) - cols = pto.index_cast(argN) + vid_idx = s.index_cast(vid) + num_cores = s.index_cast(num_blocks) + rows = s.index_cast(argM) + cols = s.index_cast(argN) total_elements = rows * cols - rows_per_core = pto.ceil_div(rows, num_cores) + rows_per_core = s.ceil_div(rows, num_cores) row_start = vid_idx * rows_per_core - tiles_per_row = pto.ceil_div(cols, c_tile) + tiles_per_row = s.ceil_div(cols, c_tile) with pto.vector_section(): tv0 = pto.as_tensor( @@ -174,14 +174,12 @@ def _2d( rows_end = row_start + rows_per_core need_truncate = rows_end > rows remaining_rows = rows - row_start - rows_to_process = pto.select( - need_truncate, remaining_rows, rows_per_core - ) + rows_to_process = s.select(need_truncate, remaining_rows, rows_per_core) - for r in pto.for_range(c0, rows_to_process, c1): + for r in pto.range(c0, rows_to_process, c1): row_idx = r + row_start row_flat_offset = row_idx * cols - for c in pto.for_range(c0, tiles_per_row, c1): + for c in pto.range(c0, tiles_per_row, c1): col_offset = c * c_tile flat_offset = row_flat_offset + col_offset diff --git a/tests/npu/elementwise_dynamic_multicore/caller.py b/tests/npu/elementwise_binary_dynamic_multicore/caller.py similarity index 97% rename from tests/npu/elementwise_dynamic_multicore/caller.py rename to 
tests/npu/elementwise_binary_dynamic_multicore/caller.py index 9fa70f0b..b9d02196 100644 --- a/tests/npu/elementwise_dynamic_multicore/caller.py +++ b/tests/npu/elementwise_binary_dynamic_multicore/caller.py @@ -9,6 +9,8 @@ "sub": "float", "add": "float", "or": "int16_t", + "max": "float", + "min": "float", } diff --git a/tests/npu/elementwise_dynamic_multicore/clean.sh b/tests/npu/elementwise_binary_dynamic_multicore/clean.sh similarity index 100% rename from tests/npu/elementwise_dynamic_multicore/clean.sh rename to tests/npu/elementwise_binary_dynamic_multicore/clean.sh diff --git a/tests/npu/elementwise_dynamic_multicore/compile.sh b/tests/npu/elementwise_binary_dynamic_multicore/compile.sh similarity index 100% rename from tests/npu/elementwise_dynamic_multicore/compile.sh rename to tests/npu/elementwise_binary_dynamic_multicore/compile.sh diff --git a/tests/npu/elementwise_dynamic_multicore/gen_ir.py b/tests/npu/elementwise_binary_dynamic_multicore/gen_ir.py similarity index 77% rename from tests/npu/elementwise_dynamic_multicore/gen_ir.py rename to tests/npu/elementwise_binary_dynamic_multicore/gen_ir.py index 2562d76c..d1c9517c 100644 --- a/tests/npu/elementwise_dynamic_multicore/gen_ir.py +++ b/tests/npu/elementwise_binary_dynamic_multicore/gen_ir.py @@ -8,15 +8,17 @@ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) -import ptodsl.language as pto -from builder import build_binary_kernels +from ptodsl import tile +from binary_builder import build_binary_kernels _OPS = { - "add": pto.add, - "sub": pto.sub, - "mul": pto.mul, - "div": pto.div, - "or": pto.or_, + "add": tile.add, + "sub": tile.sub, + "mul": tile.mul, + "div": tile.div, + "or": tile.or_, + "max": tile.max, + "min": tile.min, } if __name__ == "__main__": diff --git a/tests/npu/elementwise_dynamic_multicore/test_builder.py b/tests/npu/elementwise_binary_dynamic_multicore/test_binary_builder.py similarity index 96% rename from tests/npu/elementwise_dynamic_multicore/test_builder.py 
rename to tests/npu/elementwise_binary_dynamic_multicore/test_binary_builder.py index ba188107..bf91ad50 100644 --- a/tests/npu/elementwise_dynamic_multicore/test_builder.py +++ b/tests/npu/elementwise_binary_dynamic_multicore/test_binary_builder.py @@ -17,7 +17,9 @@ ("sub", lambda x, y: x - y), ("mul", lambda x, y: x * y), ("div", lambda x, y: x / y), - #("or", lambda x, y: x | y), #TODO add back bitwise or test after fixing int16 support in the builder + ("max", lambda x, y: torch.max(x, y)), + ("min", lambda x, y: torch.min(x, y)), + # ("or", lambda x, y: x | y), #TODO add back bitwise or test after fixing int16 support in the builder ] DTYPES = ["float32", "float16", "int16"] @@ -131,6 +133,7 @@ def test_build_binary_kernels(compiled_lib): @pytest.mark.require_npu def test_binary_1d_precision(compiled_lib): import torch_npu + torch.npu.set_device(_DEVICE) ref_fn = compiled_lib["ref_fn"] torch_dtype = TORCH_DTYPES[compiled_lib["dtype"]] @@ -162,6 +165,7 @@ def test_binary_1d_precision(compiled_lib): @pytest.mark.require_npu def test_binary_2d_precision(compiled_lib): import torch_npu + torch.npu.set_device(_DEVICE) ref_fn = compiled_lib["ref_fn"] torch_dtype = TORCH_DTYPES[compiled_lib["dtype"]] diff --git a/tests/npu/elementwise_unary_dynamic_multicore/.gitignore b/tests/npu/elementwise_unary_dynamic_multicore/.gitignore new file mode 100644 index 00000000..d44846d3 --- /dev/null +++ b/tests/npu/elementwise_unary_dynamic_multicore/.gitignore @@ -0,0 +1,4 @@ +*.pto +*.cpp +*_lib.so +caller.cpp diff --git a/tests/npu/elementwise_unary_dynamic_multicore/caller.py b/tests/npu/elementwise_unary_dynamic_multicore/caller.py new file mode 100644 index 00000000..ab6cf518 --- /dev/null +++ b/tests/npu/elementwise_unary_dynamic_multicore/caller.py @@ -0,0 +1,35 @@ +"""Generate caller.cpp for a given unary op name.""" + +import sys + +_DTYPE_TO_CTYPE = { + "float32": "float", + "float16": "half", + "int32": "int32_t", + "int16": "int16_t", +} + +_BLOCK_DIM = 24 + + +def 
generate_caller(op_name, dtype="float32"): + ctype = _DTYPE_TO_CTYPE[dtype] + return f"""\ +#include "{op_name}_{dtype}.cpp" + +extern "C" void call_kernel( + void *stream, uint8_t *x, uint8_t *y, int32_t batch, int32_t n_cols) +{{ + _kernel<<<{_BLOCK_DIM}, nullptr, stream>>>( + ({ctype} *)x, ({ctype} *)y, batch, n_cols); +}} +""" + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python caller.py [dtype]", file=sys.stderr) + sys.exit(1) + op_name = sys.argv[1] + dtype = sys.argv[2] if len(sys.argv) > 2 else "float32" + print(generate_caller(op_name, dtype)) diff --git a/tests/npu/elementwise_unary_dynamic_multicore/clean.sh b/tests/npu/elementwise_unary_dynamic_multicore/clean.sh new file mode 100644 index 00000000..5ab34240 --- /dev/null +++ b/tests/npu/elementwise_unary_dynamic_multicore/clean.sh @@ -0,0 +1,3 @@ +#!/bin/bash +rm -f *.pto *.cpp *_lib.so caller.cpp +echo "Cleaned generated files." diff --git a/tests/npu/elementwise_unary_dynamic_multicore/compile.sh b/tests/npu/elementwise_unary_dynamic_multicore/compile.sh new file mode 100644 index 00000000..84ae81cc --- /dev/null +++ b/tests/npu/elementwise_unary_dynamic_multicore/compile.sh @@ -0,0 +1,33 @@ +#!/bin/bash +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +OP=${1:?Usage: compile.sh [dtype]} +DTYPE=${2:-float32} + +TMP=$(mktemp -d) +trap "rm -rf $TMP" EXIT + +python "$SCRIPT_DIR/gen_ir.py" "$OP" "$DTYPE" > "$TMP/${OP}_${DTYPE}.pto" +ptoas --enable-insert-sync "$TMP/${OP}_${DTYPE}.pto" -o "$TMP/${OP}_${DTYPE}.cpp" + +python "$SCRIPT_DIR/caller.py" "$OP" "$DTYPE" > "$TMP/caller.cpp" + +PTO_LIB_PATH=/sources/pto-isa +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm 
-cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + "$TMP/caller.cpp" \ + -o "$SCRIPT_DIR/${OP}_${DTYPE}_lib.so" + +echo "Built ${OP}_${DTYPE}_lib.so successfully." diff --git a/tests/npu/elementwise_unary_dynamic_multicore/gen_ir.py b/tests/npu/elementwise_unary_dynamic_multicore/gen_ir.py new file mode 100644 index 00000000..b6a1d778 --- /dev/null +++ b/tests/npu/elementwise_unary_dynamic_multicore/gen_ir.py @@ -0,0 +1,35 @@ +"""Print MLIR IR for a unary op at a given dtype. + +Usage: python gen_ir.py [dtype] +""" + +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from ptodsl import tile +from unary_builder import build_unary_kernel + +_OPS = { + "rsqrt": tile.rsqrt, + "sqrt": tile.sqrt, + "exp": tile.exp, + "log": tile.log, + "relu": tile.relu, + "abs": tile.abs, + "reciprocal": tile.reciprocal, +} + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python gen_ir.py [dtype]", file=sys.stderr) + sys.exit(1) + op_name = sys.argv[1] + dtype = sys.argv[2] if len(sys.argv) > 2 else "float32" + + if op_name not in _OPS: + print(f"Unknown op: {op_name}. 
Available: {list(_OPS)}", file=sys.stderr) + sys.exit(1) + + print(build_unary_kernel(op_name, _OPS[op_name], dtype=dtype)) diff --git a/tests/npu/elementwise_unary_dynamic_multicore/test_unary_builder.py b/tests/npu/elementwise_unary_dynamic_multicore/test_unary_builder.py new file mode 100644 index 00000000..65411aff --- /dev/null +++ b/tests/npu/elementwise_unary_dynamic_multicore/test_unary_builder.py @@ -0,0 +1,141 @@ +import os +import ctypes +import subprocess + +import pytest +import torch +from ptodsl.test_util import get_test_device + +torch.manual_seed(0) + +_DIR = os.path.dirname(os.path.abspath(__file__)) +_DEVICE = get_test_device() + +UNARY_OPS = [ + ("rsqrt", lambda x: x.rsqrt()), + ("sqrt", lambda x: x.sqrt()), + ("exp", lambda x: x.exp()), + ("log", lambda x: x.log()), + ("relu", lambda x: x.relu()), + ("abs", lambda x: x.abs()), + ("reciprocal", lambda x: x.reciprocal()), +] + +DTYPES = ["float32", "float16"] + +TORCH_DTYPES = { + "float32": torch.float32, + "float16": torch.float16, +} + +_SHAPE_LIST = [ + (1, 128), + (7, 1024), + (29, 512), + (32, 2048), + (65, 4096), + (200, 8192), +] + +_SHAPE_PARAMS = [ + pytest.param(batch, n_cols, id=f"batch{batch}-cols{n_cols}") + for batch, n_cols in _SHAPE_LIST +] + +_PARAMS = [ + pytest.param((op_name, ref_fn, dtype), id=f"{op_name}-{dtype}") + for op_name, ref_fn in UNARY_OPS + for dtype in DTYPES +] + + +@pytest.fixture(scope="session", params=_PARAMS) +def compiled_lib(request): + op_name, ref_fn, dtype = request.param + subprocess.check_call( + ["bash", os.path.join(_DIR, "compile.sh"), op_name, dtype], + cwd=_DIR, + ) + yield { + "op_name": op_name, + "ref_fn": ref_fn, + "dtype": dtype, + "lib_path": _lib_path(op_name, dtype), + } + os.remove(_lib_path(op_name, dtype)) + + +def _make_input(shape, device, dtype, op_name): + """Return a suitable input tensor for the given op. + + rsqrt: inputs in (1.0, 2.0] — keeps outputs near 1.0 + so float16 absolute error stays within 2e-3. 
+ sqrt/log: inputs in (0.1, 1.1]. + exp: inputs in (-0.5, 0.5] to avoid float16 overflow. + relu/abs: inputs in (-1.0, 1.0] to exercise both signs. + """ + if op_name in {"rsqrt", "reciprocal"}: + return torch.rand(shape, device=device, dtype=dtype) + 1.0 + elif op_name in {"sqrt", "log"}: + return torch.rand(shape, device=device, dtype=dtype) + 0.1 + elif op_name == "exp": + return torch.rand(shape, device=device, dtype=dtype) - 0.5 + else: + return torch.rand(shape, device=device, dtype=dtype) * 2.0 - 1.0 + + +def _lib_to_func_unary(lib): + lib.call_kernel.argtypes = [ + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int32, + ctypes.c_int32, + ] + lib.call_kernel.restype = None + + def fn(x, y): + stream_ptr = torch.npu.current_stream()._as_parameter_ + lib.call_kernel( + stream_ptr, + ctypes.c_void_p(x.data_ptr()), + ctypes.c_void_p(y.data_ptr()), + ctypes.c_int32(x.size(0)), + ctypes.c_int32(x.size(1)), + ) + + return fn + + +def _lib_path(op_name, dtype): + return os.path.join(_DIR, f"{op_name}_{dtype}_lib.so") + + +def test_build_unary_kernels(compiled_lib): + assert os.path.exists(_lib_path(compiled_lib["op_name"], compiled_lib["dtype"])) + + +@pytest.mark.require_npu +@pytest.mark.parametrize("batch, n_cols", _SHAPE_PARAMS) +def test_unary_precision(compiled_lib, batch, n_cols): + import torch_npu # noqa: F401 + + torch.npu.set_device(_DEVICE) + op_name = compiled_lib["op_name"] + ref_fn = compiled_lib["ref_fn"] + torch_dtype = TORCH_DTYPES[compiled_lib["dtype"]] + + lib = ctypes.CDLL(compiled_lib["lib_path"]) + kernel = _lib_to_func_unary(lib) + + x = _make_input((batch, n_cols), _DEVICE, torch_dtype, op_name) + y = torch.empty(batch, n_cols, device=_DEVICE, dtype=torch_dtype) + kernel(x, y) + torch.npu.synchronize() + y_ref = ref_fn(x) + torch.npu.synchronize() + torch.testing.assert_close(y, y_ref, atol=2e-3, rtol=1e-3) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git 
a/tests/npu/elementwise_unary_dynamic_multicore/unary_builder.py b/tests/npu/elementwise_unary_dynamic_multicore/unary_builder.py new file mode 100644 index 00000000..e652d33a --- /dev/null +++ b/tests/npu/elementwise_unary_dynamic_multicore/unary_builder.py @@ -0,0 +1,116 @@ +from ptodsl import pto, to_ir_module +from ptodsl import scalar as s + +const = s.const + +# 32 KB of UB +_TILE_SIZE_BYTES = 32 * 1024 +_DTYPE_BYTES = {"float32": 4, "float16": 2} + + +def meta_data(dtype="float32"): + pto_dtype = {"float32": pto.float32, "float16": pto.float16}[dtype] + elements_per_tile = _TILE_SIZE_BYTES // _DTYPE_BYTES[dtype] + ptr_type = pto.PtrType(pto_dtype) + index_dtype = pto.int32 + + tensor_type = pto.TensorType(rank=1, dtype=pto_dtype) + subtensor_type = pto.SubTensorType(shape=[1, elements_per_tile], dtype=pto_dtype) + + tile_cfg = pto.TileBufConfig() + tile_type = pto.TileBufType( + shape=[1, elements_per_tile], + valid_shape=[1, -1], + dtype=pto_dtype, + memory_space="VEC", + config=tile_cfg, + ) + + return { + "ptr_type": ptr_type, + "pto_dtype": pto_dtype, + "elements_per_tile": elements_per_tile, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_type": subtensor_type, + "tile_type": tile_type, + } + + +def build_unary_kernel(op_name, op_fn, dtype="float32"): + """ + Dynamic multicore unary elementwise kernel. 
+ + Args: + x_ptr : dtype[batch * n_cols] input matrix, row-major + y_ptr : dtype[batch * n_cols] output matrix + batch_i32 : int32 number of rows + n_cols_i32 : int32 elements per row; must be <= elements_per_tile + + Semantics: + y[r, c] = op(x[r, c]) + """ + _meta_data = lambda: meta_data(dtype=dtype) + + @to_ir_module(meta_data=_meta_data) + def _kernel( + x_ptr: "ptr_type", + y_ptr: "ptr_type", + batch_i32: "index_dtype", + n_cols_i32: "index_dtype", + ) -> None: + c0 = const(0) + c1 = const(1) + c_tile = const(elements_per_tile) + + batch = s.index_cast(batch_i32) + n_cols = s.index_cast(n_cols_i32) + + with pto.vector_section(): + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + + vid = s.index_cast(cid * sub_bnum + sub_bid) + num_cores = s.index_cast(num_blocks * sub_bnum) + + rows_per_core = s.ceil_div(batch, num_cores) + row_start = vid * rows_per_core + row_end = s.min_u(row_start + rows_per_core, batch) + num_rows = row_end - row_start + + total_elems = batch * n_cols + tv_x = pto.as_tensor( + tensor_type, ptr=x_ptr, shape=[total_elems], strides=[c1] + ) + tv_y = pto.as_tensor( + tensor_type, ptr=y_ptr, shape=[total_elems], strides=[c1] + ) + + with pto.if_context(num_rows > c0): + tb_x = pto.alloc_tile(tile_type, valid_col=n_cols) + tb_y = pto.alloc_tile(tile_type, valid_col=n_cols) + + for row_i in pto.range(c0, num_rows, c1): + gm_offset = (row_start + row_i) * n_cols + + sv_x = pto.slice_view( + subtensor_type, + source=tv_x, + offsets=[gm_offset], + sizes=[n_cols], + ) + sv_y = pto.slice_view( + subtensor_type, + source=tv_y, + offsets=[gm_offset], + sizes=[n_cols], + ) + + pto.load(sv_x, tb_x) + op_fn(tb_x, tb_y) + pto.store(tb_y, sv_y) + + _ = op_name + return _kernel diff --git a/tests/npu/expand_dynamic_multicore/caller.py b/tests/npu/expand_dynamic_multicore/caller.py new file mode 100644 index 00000000..ca5f5090 --- /dev/null +++ 
b/tests/npu/expand_dynamic_multicore/caller.py @@ -0,0 +1,69 @@ +"""Generate caller.cpp for dynamic multicore col/row expand kernels. + +Usage: + python caller.py --mode colexpand|rowexpand|rowexpand_mul|rowexpand_sub|rowexpand_div +""" + +_FUSED_MODES = {"rowexpand_mul", "rowexpand_sub", "rowexpand_div"} + + +def generate_caller(mode, dtype): + ctype = "half" if dtype == "fp16" else "float" + if mode in _FUSED_MODES: + return f"""\ +#include "{mode}.cpp" + +extern "C" void call_{mode}( + uint32_t blockDim, + void *stream, + uint8_t *x, + uint8_t *y, + uint8_t *z, + uint32_t batch, + uint32_t n_cols) +{{ + _kernel<<<blockDim, nullptr, stream>>>( + reinterpret_cast<{ctype} *>(x), + reinterpret_cast<{ctype} *>(y), + reinterpret_cast<{ctype} *>(z), + static_cast<int32_t>(batch), + static_cast<int32_t>(n_cols)); +}} +""" + return f"""\ +#include "{mode}.cpp" + +extern "C" void call_{mode}( + uint32_t blockDim, + void *stream, + uint8_t *src, + uint8_t *dst, + uint32_t batch, + uint32_t n_cols) +{{ + _kernel<<<blockDim, nullptr, stream>>>( + reinterpret_cast<{ctype} *>(src), + reinterpret_cast<{ctype} *>(dst), + static_cast<int32_t>(batch), + static_cast<int32_t>(n_cols)); +}} +""" + + +if __name__ == "__main__": + import argparse + + MODES = [ + "colexpand", + "rowexpand", + "rowexpand_mul", + "rowexpand_sub", + "rowexpand_div", + ] + + parser = argparse.ArgumentParser() + parser.add_argument("--mode", choices=MODES, required=True) + parser.add_argument("--dtype", choices=["fp16", "fp32"], default="fp32") + args = parser.parse_args() + + print(generate_caller(args.mode, args.dtype)) diff --git a/tests/npu/expand_dynamic_multicore/compile.sh b/tests/npu/expand_dynamic_multicore/compile.sh new file mode 100755 index 00000000..54cbabfd --- /dev/null +++ b/tests/npu/expand_dynamic_multicore/compile.sh @@ -0,0 +1,44 @@ +#!/bin/bash +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +TMP=$(mktemp -d) +trap "rm -rf \"$TMP\"" EXIT + + +PTO_LIB_PATH=/sources/pto-isa +BISHENG_FLAGS=( + -I${PTO_LIB_PATH}/include + -fPIC -shared -D_FORTIFY_SOURCE=2
-O2 -std=c++17 + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong + -xcce -Xhost-start -Xhost-end + -mllvm -cce-aicore-stack-size=0x8000 + -mllvm -cce-aicore-function-stack-size=0x8000 + -mllvm -cce-aicore-record-overflow=true + -mllvm -cce-aicore-addr-transform + -mllvm -cce-aicore-dcci-insert-for-scalar=false + --npu-arch=dav-2201 -DMEMORY_BASE + -std=gnu++17 +) + +MODES=( + colexpand + rowexpand + rowexpand_mul + rowexpand_sub + rowexpand_div +) + +for MODE in "${MODES[@]}"; do + python "$SCRIPT_DIR/gen_ir.py" --mode "$MODE" > "$TMP/${MODE}.pto" + ptoas --enable-insert-sync "$TMP/${MODE}.pto" -o "$TMP/${MODE}.cpp" + + python "$SCRIPT_DIR/caller.py" --mode "$MODE" > "$TMP/${MODE}_caller.cpp" + + bisheng "${BISHENG_FLAGS[@]}" \ + "$TMP/${MODE}_caller.cpp" \ + -o "$SCRIPT_DIR/${MODE}_lib.so" + + echo "Built ${MODE}_lib.so successfully." +done diff --git a/tests/npu/expand_dynamic_multicore/expand_builder.py b/tests/npu/expand_dynamic_multicore/expand_builder.py new file mode 100644 index 00000000..aea958f7 --- /dev/null +++ b/tests/npu/expand_dynamic_multicore/expand_builder.py @@ -0,0 +1,355 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +_TILE_ROWS = 32 +_TILE_COLS = 32 + + +def meta_data_expand(dtype="fp32"): + pto_dtype = {"fp16": pto.float16, "fp32": pto.float32}[dtype] + ptr_type = pto.PtrType(pto_dtype) + index_dtype = pto.int32 + + tile_rows = _TILE_ROWS + tile_cols = _TILE_COLS + + tensor2d_type = pto.TensorType(rank=2, dtype=pto_dtype) + + # For col_expand: src slice is [1, tile_cols] (one row of the input vector) + subtensor_col_src = pto.SubTensorType(shape=[1, tile_cols], dtype=pto_dtype) + # For row_expand: src slice is [tile_rows, 1] (one column of the input vector) + subtensor_row_src = pto.SubTensorType(shape=[tile_rows, 1], dtype=pto_dtype) + # For loading/storing the 2D matrix + subtensor_dst = pto.SubTensorType(shape=[tile_rows, tile_cols], dtype=pto_dtype) + + tile_cfg = 
pto.TileBufConfig() + tile_type = pto.TileBufType( + shape=[tile_rows, tile_cols], + valid_shape=[-1, -1], + dtype=pto_dtype, + memory_space="VEC", + config=tile_cfg, + ) + return { + "ptr_type": ptr_type, + "pto_dtype": pto_dtype, + "index_dtype": index_dtype, + "tensor2d_type": tensor2d_type, + "subtensor_col_src": subtensor_col_src, + "subtensor_row_src": subtensor_row_src, + "subtensor_dst": subtensor_dst, + "tile_type": tile_type, + "tile_rows": tile_rows, + "tile_cols": tile_cols, + } + + +def build_col_expand(dtype="fp32"): + """ + Column-wise broadcast: replicate each element of y[j] across all rows. + + Semantics: + X[i, j] = y[j] + """ + _meta_data = lambda: meta_data_expand(dtype=dtype) + + @to_ir_module(meta_data=_meta_data) + def _kernel( + y_ptr: "ptr_type", + x_ptr: "ptr_type", + batch_i32: "index_dtype", + n_cols_i32: "index_dtype", + ) -> None: + c0 = const(0) + c1 = const(1) + c_tile_rows = const(tile_rows) + c_tile_cols = const(tile_cols) + + batch = s.index_cast(batch_i32) + n_cols = s.index_cast(n_cols_i32) + + with pto.vector_section(): + bid = s.index_cast(pto.get_block_idx()) + num_cores = s.index_cast(pto.get_block_num()) + + cols_per_core = s.ceil_div(n_cols, num_cores) + col_start = bid * cols_per_core + col_end = s.min_u(col_start + cols_per_core, n_cols) + + # y[n_cols] represented as 2D [1, n_cols] for uniform slice_view usage + tv_y = pto.as_tensor( + tensor2d_type, + ptr=y_ptr, + shape=[c1, n_cols], + strides=[n_cols, c1], + ) + tv_x = pto.as_tensor( + tensor2d_type, + ptr=x_ptr, + shape=[batch, n_cols], + strides=[n_cols, c1], + ) + + for col in pto.range(col_start, col_end, c_tile_cols): + cols_this = s.min_u(c_tile_cols, col_end - col) + + # Load one row of y into the src tile (valid_row=1) + tb_src = pto.alloc_tile(tile_type, valid_row=c1, valid_col=cols_this) + sv_y = pto.slice_view( + subtensor_col_src, + source=tv_y, + offsets=[c0, col], + sizes=[c1, cols_this], + ) + pto.load(sv_y, tb_src) + + for row in pto.range(c0, batch, 
c_tile_rows): + rows_this = s.min_u(c_tile_rows, batch - row) + + tb_dst = pto.alloc_tile( + tile_type, valid_row=rows_this, valid_col=cols_this + ) + tile.col_expand(tb_src, tb_dst) + + sv_x = pto.slice_view( + subtensor_dst, + source=tv_x, + offsets=[row, col], + sizes=[rows_this, cols_this], + ) + pto.store(tb_dst, sv_x) + + return _kernel + + +def build_row_expand(dtype="fp32"): + """ + Row-wise broadcast: replicate each element of x[i] across all columns. + + Semantics: + Y[i,j] = x[i] + """ + _meta_data = lambda: meta_data_expand(dtype=dtype) + + @to_ir_module(meta_data=_meta_data) + def _kernel( + x_ptr: "ptr_type", + y_ptr: "ptr_type", + batch_i32: "index_dtype", + n_cols_i32: "index_dtype", + ) -> None: + c0 = const(0) + c1 = const(1) + c_tile_rows = const(tile_rows) + c_tile_cols = const(tile_cols) + + batch = s.index_cast(batch_i32) + n_cols = s.index_cast(n_cols_i32) + + with pto.vector_section(): + bid = s.index_cast(pto.get_block_idx()) + num_cores = s.index_cast(pto.get_block_num()) + + rows_per_core = s.ceil_div(batch, num_cores) + row_start = bid * rows_per_core + row_end = s.min_u(row_start + rows_per_core, batch) + + # x[batch] represented as 2D [batch, 1] for uniform slice_view usage + tv_x = pto.as_tensor( + tensor2d_type, + ptr=x_ptr, + shape=[batch, c1], + strides=[c1, c1], + ) + tv_y = pto.as_tensor( + tensor2d_type, + ptr=y_ptr, + shape=[batch, n_cols], + strides=[n_cols, c1], + ) + + for row in pto.range(row_start, row_end, c_tile_rows): + rows_this = s.min_u(c_tile_rows, row_end - row) + + # Load one column of x into the src tile (valid_col=1) + tb_src = pto.alloc_tile(tile_type, valid_row=rows_this, valid_col=c1) + sv_x = pto.slice_view( + subtensor_row_src, + source=tv_x, + offsets=[row, c0], + sizes=[rows_this, c1], + ) + pto.load(sv_x, tb_src) + + for col in pto.range(c0, n_cols, c_tile_cols): + cols_this = s.min_u(c_tile_cols, n_cols - col) + + tb_dst = pto.alloc_tile( + tile_type, valid_row=rows_this, valid_col=cols_this + ) + 
tile.row_expand(tb_src, tb_dst) + + sv_y = pto.slice_view( + subtensor_dst, + source=tv_y, + offsets=[row, col], + sizes=[rows_this, cols_this], + ) + pto.store(tb_dst, sv_y) + + return _kernel + + +# Fused row-expand ops: dst[i,j] = src0[i,j] op src1[0,i] +# src1 is a row-vector tile (valid_row=1, valid_col=rows_this) +# so src1[0,i] = x[row+i] per the hardware op convention. +_ROW_EXPAND_FUSED_OPS = { + "expand_mul": tile.row_expand_mul, + "expand_sub": tile.row_expand_sub, + "expand_div": tile.row_expand_div, +} + + +def _build_row_expand_fused(kind, dtype="fp32"): + """ + Fused row-expand: apply element-wise op between Y[i,j] and x[i]. + + Semantics: + expand_mul: Y[i,j] *= x[i] + expand_sub: Y[i,j] -= x[i] + expand_div: Y[i,j] /= x[i] + + src1 tile is a scalar [1, 1]: src1[0,0] = x[row], one row at a time. + """ + row_op = _ROW_EXPAND_FUSED_OPS[kind] + _meta_data = lambda: meta_data_expand(dtype=dtype) + + @to_ir_module(meta_data=_meta_data) + def _kernel( + x_ptr: "ptr_type", + y_ptr: "ptr_type", + z_ptr: "ptr_type", + batch_i32: "index_dtype", + n_cols_i32: "index_dtype", + ) -> None: + c0 = const(0) + c1 = const(1) + c_tile_cols = const(tile_cols) + + batch = s.index_cast(batch_i32) + n_cols = s.index_cast(n_cols_i32) + + with pto.vector_section(): + bid = s.index_cast(pto.get_block_idx()) + num_cores = s.index_cast(pto.get_block_num()) + + rows_per_core = s.ceil_div(batch, num_cores) + row_start = bid * rows_per_core + row_end = s.min_u(row_start + rows_per_core, batch) + + # y[batch, n_cols] - input matrix (src0) + tv_y = pto.as_tensor( + tensor2d_type, + ptr=y_ptr, + shape=[batch, n_cols], + strides=[n_cols, c1], + ) + # z[batch, n_cols] - output matrix (dst) + tv_z = pto.as_tensor( + tensor2d_type, + ptr=z_ptr, + shape=[batch, n_cols], + strides=[n_cols, c1], + ) + # x as column vector [batch, 1]: x[row] stored at tv_x[row, 0] + tv_x = pto.as_tensor( + tensor2d_type, + ptr=x_ptr, + shape=[batch, c1], + strides=[c1, c1], + ) + + # Process one row at a 
time so tb_src1 always has rows_this=1, + # making src1[0,0] = x[row] unambiguous for both row/col conventions. + for row in pto.range(row_start, row_end, c1): + # Load scalar x[row] into a [1, 1] tile: src1[0,0] = x[row] + tb_src1 = pto.alloc_tile(tile_type, valid_row=c1, valid_col=c1) + sv_x = pto.slice_view( + subtensor_row_src, + source=tv_x, + offsets=[row, c0], + sizes=[c1, c1], + ) + pto.load(sv_x, tb_src1) + + for col in pto.range(c0, n_cols, c_tile_cols): + cols_this = s.min_u(c_tile_cols, n_cols - col) + + sv_y = pto.slice_view( + subtensor_dst, + source=tv_y, + offsets=[row, col], + sizes=[c1, cols_this], + ) + sv_z = pto.slice_view( + subtensor_dst, + source=tv_z, + offsets=[row, col], + sizes=[c1, cols_this], + ) + + # src0 = one row of Y, src1 = scalar x[row], dst = one row of Z + tb_src0 = pto.alloc_tile( + tile_type, valid_row=c1, valid_col=cols_this + ) + pto.load(sv_y, tb_src0) + + tb_dst = pto.alloc_tile( + tile_type, valid_row=c1, valid_col=cols_this + ) + row_op(tb_src0, tb_src1, tb_dst) + + pto.store(tb_dst, sv_z) + + return _kernel + + +def build_row_expand_mul(dtype="fp32"): + return _build_row_expand_fused("expand_mul", dtype=dtype) + + +def build_row_expand_sub(dtype="fp32"): + return _build_row_expand_fused("expand_sub", dtype=dtype) + + +def build_row_expand_div(dtype="fp32"): + return _build_row_expand_fused("expand_div", dtype=dtype) + + +if __name__ == "__main__": + import argparse + + _MODES = [ + "colexpand", + "rowexpand", + "rowexpand_mul", + "rowexpand_sub", + "rowexpand_div", + ] + + parser = argparse.ArgumentParser() + parser.add_argument("--mode", choices=_MODES, default="colexpand") + parser.add_argument("--dtype", choices=["fp16", "fp32"], default="fp32") + args = parser.parse_args() + + builders = { + "colexpand": build_col_expand, + "rowexpand": build_row_expand, + "rowexpand_mul": build_row_expand_mul, + "rowexpand_sub": build_row_expand_sub, + "rowexpand_div": build_row_expand_div, + } + + 
print(builders[args.mode](dtype=args.dtype)) diff --git a/tests/npu/expand_dynamic_multicore/gen_ir.py b/tests/npu/expand_dynamic_multicore/gen_ir.py new file mode 100644 index 00000000..649cef89 --- /dev/null +++ b/tests/npu/expand_dynamic_multicore/gen_ir.py @@ -0,0 +1,39 @@ +"""Print MLIR IR for dynamic multicore col/row expand kernels. + +Usage: + python gen_ir.py --mode colexpand|rowexpand|rowexpand_mul|rowexpand_sub|rowexpand_div +""" + +import argparse +import os +import sys + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from expand_builder import ( + build_col_expand, + build_row_expand, + build_row_expand_div, + build_row_expand_mul, + build_row_expand_sub, +) + +_BUILDERS = { + "colexpand": build_col_expand, + "rowexpand": build_row_expand, + "rowexpand_mul": build_row_expand_mul, + "rowexpand_sub": build_row_expand_sub, + "rowexpand_div": build_row_expand_div, +} + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + choices=list(_BUILDERS.keys()), + default="colexpand", + ) + parser.add_argument("--dtype", choices=["fp16", "fp32"], default="fp32") + args = parser.parse_args() + + print(_BUILDERS[args.mode](dtype=args.dtype)) diff --git a/tests/npu/expand_dynamic_multicore/test_expand.py b/tests/npu/expand_dynamic_multicore/test_expand.py new file mode 100644 index 00000000..ac3a4ec9 --- /dev/null +++ b/tests/npu/expand_dynamic_multicore/test_expand.py @@ -0,0 +1,169 @@ +import ctypes +import os +import subprocess + +import pytest +import torch + +from ptodsl.test_util import get_test_device + +torch.manual_seed(0) + +_DIR = os.path.dirname(os.path.abspath(__file__)) +_DEVICE = get_test_device() +_BLOCK_DIM = 24 + +_KERNELS = [ + "colexpand", + "rowexpand", + "rowexpand_mul", + "rowexpand_sub", + "rowexpand_div", +] + +_LIB_PATHS = {name: os.path.join(_DIR, f"{name}_lib.so") for name in _KERNELS} + +_SHAPES = [ + (1, 1), + (7, 7), + (15, 17), + (31, 33), + (33, 31), + (64, 32), + (32, 
64), + (65, 64), + (29, 257), + (127, 129), +] + +_SHAPE_IDS = [f"batch{batch}-cols{n_cols}" for batch, n_cols in _SHAPES] + + +@pytest.fixture(scope="session") +def compiled_kernels(): + subprocess.check_call(["bash", os.path.join(_DIR, "compile.sh")], cwd=_DIR) + yield + for path in _LIB_PATHS.values(): + if os.path.exists(path): + os.remove(path) + + +_FUSED_KERNELS = {"rowexpand_mul", "rowexpand_sub", "rowexpand_div"} + + +def _load_kernel(name): + lib = ctypes.CDLL(_LIB_PATHS[name]) + fn = getattr(lib, f"call_{name}") + if name in _FUSED_KERNELS: + # fused: (blockDim, stream, x, y, z, batch, n_cols) + fn.argtypes = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_uint32, + ctypes.c_uint32, + ] + else: + fn.argtypes = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_uint32, + ctypes.c_uint32, + ] + fn.restype = None + return fn + + +def _make_inputs(name, batch, n_cols, device): + if name == "colexpand": + src = torch.randn(n_cols, device=device, dtype=torch.float32) + dst = torch.zeros((batch, n_cols), device=device, dtype=torch.float32) + return src, dst, None + if name == "rowexpand": + src = torch.randn(batch, device=device, dtype=torch.float32) + dst = torch.zeros((batch, n_cols), device=device, dtype=torch.float32) + return src, dst, None + if name == "rowexpand_div": + # avoid division by zero: keep x away from 0 + x = torch.empty(batch, device=device, dtype=torch.float32).uniform_(0.5, 1.5) + y = torch.randn(batch, n_cols, device=device, dtype=torch.float32) + z = torch.zeros((batch, n_cols), device=device, dtype=torch.float32) + return x, y, z + # rowexpand_mul, rowexpand_sub + x = torch.randn(batch, device=device, dtype=torch.float32) + y = torch.randn(batch, n_cols, device=device, dtype=torch.float32) + z = torch.zeros((batch, n_cols), device=device, dtype=torch.float32) + return x, y, z + + +def _reference(name, x, y): + if name == "colexpand": + 
return x.float().unsqueeze(0).expand_as(y) + if name == "rowexpand": + return x.float().unsqueeze(1).expand_as(y) + if name == "rowexpand_mul": + return y.float() * x.float().unsqueeze(1) + if name == "rowexpand_sub": + return y.float() - x.float().unsqueeze(1) + if name == "rowexpand_div": + return y.float() / x.float().unsqueeze(1) + raise ValueError(f"Unknown kernel: {name}") + + +def _tolerances(name): + if name in {"colexpand", "rowexpand"}: + return {"atol": 0, "rtol": 0} + return {"atol": 1e-4, "rtol": 1e-4} + + +@pytest.mark.parametrize("name", _KERNELS) +def test_build_kernel(compiled_kernels, name): + assert os.path.exists(_LIB_PATHS[name]) + + +@pytest.mark.require_npu +@pytest.mark.parametrize("name", _KERNELS) +@pytest.mark.parametrize("batch, n_cols", _SHAPES, ids=_SHAPE_IDS) +def test_kernel_precision(compiled_kernels, name, batch, n_cols): + import torch_npu # noqa: F401 + + torch.npu.set_device(_DEVICE) + + fn = _load_kernel(name) + + x, y, z = _make_inputs(name, batch, n_cols, _DEVICE) + dst_ref = _reference(name, x, y) + + stream_ptr = torch.npu.current_stream()._as_parameter_ + if name in _FUSED_KERNELS: + fn( + ctypes.c_uint32(_BLOCK_DIM), + stream_ptr, + ctypes.c_void_p(x.data_ptr()), + ctypes.c_void_p(y.data_ptr()), + ctypes.c_void_p(z.data_ptr()), + ctypes.c_uint32(batch), + ctypes.c_uint32(n_cols), + ) + out = z + else: + fn( + ctypes.c_uint32(_BLOCK_DIM), + stream_ptr, + ctypes.c_void_p(x.data_ptr()), + ctypes.c_void_p(y.data_ptr()), + ctypes.c_uint32(batch), + ctypes.c_uint32(n_cols), + ) + out = y + torch.npu.synchronize() + + torch.testing.assert_close(out, dst_ref, **_tolerances(name)) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/npu/gather_dynamic_multicore/builder.py b/tests/npu/gather_dynamic_multicore/builder.py index b5bbbc29..1851ff54 100644 --- a/tests/npu/gather_dynamic_multicore/builder.py +++ b/tests/npu/gather_dynamic_multicore/builder.py @@ -1,7 +1,7 @@ -from ptodsl import 
to_ir_module -import ptodsl.language as pto +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s -const = pto.const +const = s.const DTYPES = { "float32": lambda: pto.float32, @@ -88,18 +88,18 @@ def _kernel( c1 = const(1) c_tile = const(tile_length) - total_elements = pto.index_cast(argB * argN) # B * N + total_elements = s.index_cast(argB * argN) # B * N cid = pto.get_block_idx() sub_bid = pto.get_subblock_idx() sub_bnum = pto.get_subblock_num() vid = cid * sub_bnum + sub_bid num_blocks = pto.get_block_num() - vid_idx = pto.index_cast(vid) - num_cores = pto.index_cast(num_blocks) + vid_idx = s.index_cast(vid) + num_cores = s.index_cast(num_blocks) - num_tiles_global = pto.ceil_div(total_elements, c_tile) - num_tiles_per_core = pto.ceil_div(num_tiles_global, num_cores) + num_tiles_global = s.ceil_div(total_elements, c_tile) + num_tiles_per_core = s.ceil_div(num_tiles_global, num_cores) tile_offset_this_core = vid_idx * num_tiles_per_core with pto.vector_section(): @@ -123,13 +123,13 @@ def _kernel( tiles_end_this_core = tile_offset_this_core + num_tiles_per_core need_truncate = tiles_end_this_core > num_tiles_global remaining_tiles = num_tiles_global - tile_offset_this_core - tiles_to_process = pto.select( + tiles_to_process = s.select( need_truncate, remaining_tiles, num_tiles_per_core ) elements_to_process = tiles_to_process * c_tile with pto.if_context(elements_to_process > c0): - for i in pto.for_range(c0, tiles_to_process, c1): + for i in pto.range(c0, tiles_to_process, c1): tile_offset_global = i + tile_offset_this_core offset_global = tile_offset_global * c_tile @@ -150,9 +150,9 @@ def _kernel( pto.load(sv1, tb_idx) # gather within tile by indices - pto.gather(tb_src, tb_tmp, tb_idx) + tile.gather(tb_src, tb_tmp, tb_idx) - pto.gather(tb_tmp, tb_out, mask_pattern=mask_pattern) + tile.gather(tb_tmp, tb_out, mask_pattern=mask_pattern) sv2 = pto.slice_view( subtensor_type, diff --git 
a/tests/npu/gather_dynamic_multicore/test_gather_dynamic.py b/tests/npu/gather_dynamic_multicore/test_gather_dynamic.py index 507c7259..1be4f4f5 100644 --- a/tests/npu/gather_dynamic_multicore/test_gather_dynamic.py +++ b/tests/npu/gather_dynamic_multicore/test_gather_dynamic.py @@ -23,7 +23,7 @@ ("float16", "P0101"), ("float16", "P1111"), ("float16", "P0001"), - ("float16", "P1010") + ("float16", "P1010"), ] # Runtime shapes (B, N). N must be a multiple of 32. @@ -168,7 +168,9 @@ def test_build_gather(compiled_lib): @pytest.mark.require_npu -@pytest.mark.xfail(reason="Known unsolved issues of indeterministic output values", strict=False) +@pytest.mark.xfail( + reason="Known unsolved issues of indeterministic output values", strict=False +) @pytest.mark.parametrize("B, N", _SHAPE_PARAMS) def test_gather_dynamic(compiled_lib, B, N): import torch_npu @@ -195,9 +197,7 @@ def test_gather_dynamic(compiled_lib, B, N): ref = _gather_ref_blocked(src, indices, mask_pattern, num_blocks=NUM_BLOCKS) - torch.testing.assert_close( - out, ref, msg=f"shape=({B},{N}), mask={mask_pattern}" - ) + torch.testing.assert_close(out, ref, msg=f"shape=({B},{N}), mask={mask_pattern}") if __name__ == "__main__": diff --git a/tests/npu/gather_static_singlecore/builder.py b/tests/npu/gather_static_singlecore/builder.py index 417a76e7..2f7357c9 100644 --- a/tests/npu/gather_static_singlecore/builder.py +++ b/tests/npu/gather_static_singlecore/builder.py @@ -1,7 +1,7 @@ -from ptodsl import to_ir_module -import ptodsl.language as pto +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s -const = pto.const +const = s.const _DTYPE_MAP = { "float32": lambda: pto.float32, @@ -90,8 +90,8 @@ def _kernel( pto.load(sv0, tb0) pto.load(sv1, tb1) - pto.gather(tb0, tb2, tb1) # index-gather: tb2[i,j] = tb0[tb1[i,j]] - pto.gather( + tile.gather(tb0, tb2, tb1) # index-gather: tb2[i,j] = tb0[tb1[i,j]] + tile.gather( tb2, tb3, mask_pattern=mask_pattern ) # mask-gather with configurable 
pattern diff --git a/tests/npu/mrgsort_dynamic_multicore/builder.py b/tests/npu/mrgsort_dynamic_multicore/builder.py new file mode 100644 index 00000000..3e054ce4 --- /dev/null +++ b/tests/npu/mrgsort_dynamic_multicore/builder.py @@ -0,0 +1,160 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +DTYPES = { + "float32": lambda: pto.float32, + "float16": lambda: pto.float16, +} + +# TMRGSORT's blockLen parameter is in float32-word units: +# hw_block_len = block_len * (sizeof(float) / sizeof(T)) +_TYPE_COEF = {"float32": 1, "float16": 2} + + +def meta_data(dtype=None, tile_length=1024): + if dtype is None: + dtype = "float32" + if isinstance(dtype, str): + dtype = DTYPES[dtype]() + + index_dtype = pto.int32 + ptr_type = pto.PtrType(dtype) + # 2D tensor view: shape [num_tiles, tile_length], matching the expand_builder pattern. + tensor_type = pto.TensorType(rank=2, dtype=dtype) + subtensor_type = pto.SubTensorType(shape=[1, tile_length], dtype=dtype) + tile_cfg = pto.TileBufConfig() + tile_type = pto.TileBufType( + shape=[1, tile_length], + valid_shape=[1, tile_length], + dtype=dtype, + memory_space="VEC", + config=tile_cfg, + ) + return { + "ptr_type": ptr_type, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_type": subtensor_type, + "tile_type": tile_type, + "tile_length": tile_length, + } + + +def build_mrgsort_kernel( + fn_name="vec_mrgsort_1d_dynamic_float32", + dtype="float32", + tile_length=1024, + block_len=32, +): + """Build a 1D dynamic multicore merge-sort kernel. + + Each tile of tile_length elements is treated as containing + tile_length // block_len pre-sorted sub-lists of block_len elements. + TMRGSORT merges groups of 4 sub-lists (block_len*4 elements) independently; + repeatTimes = tile_length // (block_len * 4) such groups per tile. 
+ + The hardware blockLen passed to TMRGSORT is scaled by TYPE_COEF + (= sizeof(float) / sizeof(T)) per the instruction's float32-word semantics: + hw_block_len = block_len * TYPE_COEF + + Constraints (enforced by TMRGSORT): + - hw_block_len must be a multiple of 64 + - tile_length must be a multiple of hw_block_len * 4 + - repeatTimes = tile_length / (hw_block_len * 4) must be in [1, 255] + """ + dtype_str = dtype if isinstance(dtype, str) else "float32" + hw_block_len = block_len * _TYPE_COEF.get(dtype_str, 1) + _meta_data = lambda: meta_data(dtype=dtype, tile_length=tile_length) + + def _kernel( + arg0: "ptr_type", # src: input with sorted sub-lists + arg1: "ptr_type", # out: merged sorted output + argN: "index_dtype", # total number of elements (multiple of tile_length) + ) -> None: + assert tile_length % (hw_block_len * 4) == 0 + assert hw_block_len % 64 == 0 + c0 = const(0) + c1 = const(1) + c_tile = const(tile_length) + + total_elements = s.index_cast(argN) + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + vid = cid * sub_bnum + sub_bid + num_blocks = pto.get_block_num() + + vid_idx = s.index_cast(vid) + # Total virtual cores = num_blocks * subblock_num (matches add_dynamic_multicore). + num_cores = s.index_cast(num_blocks * sub_bnum) + + num_tiles_global = s.ceil_div(total_elements, c_tile) + num_tiles_per_core = s.ceil_div(num_tiles_global, num_cores) + tile_offset_this_core = vid_idx * num_tiles_per_core + + with pto.vector_section(): + # 2D tensor views: shape=[num_tiles, tile_length], strides=[tile_length, 1]. + # Mirrors the expand_builder layout where rows are tiles and columns are elements. 
+ tv0 = pto.as_tensor( + tensor_type, + ptr=arg0, + shape=[num_tiles_global, c_tile], + strides=[c_tile, c1], + ) + tv1 = pto.as_tensor( + tensor_type, + ptr=arg1, + shape=[num_tiles_global, c_tile], + strides=[c_tile, c1], + ) + + tb_src = pto.alloc_tile(tile_type) + tb_tmp = pto.alloc_tile(tile_type) + tb_dst = pto.alloc_tile(tile_type) + + with pto.if_context(tile_offset_this_core < num_tiles_global): + tiles_end_this_core = tile_offset_this_core + num_tiles_per_core + need_truncate = tiles_end_this_core > num_tiles_global + remaining_tiles = num_tiles_global - tile_offset_this_core + tiles_to_process = s.select( + need_truncate, remaining_tiles, num_tiles_per_core + ) + + with pto.if_context(tiles_to_process > c0): + for i in pto.range(c0, tiles_to_process, c1): + tile_idx = i + tile_offset_this_core + + sv0 = pto.slice_view( + subtensor_type, + source=tv0, + offsets=[tile_idx, c0], + sizes=[c1, c_tile], + ) + + pto.load(sv0, tb_src) + # Multi-pass merge sort: blockLen doubles each pass (reference + # MrgsortSingleRow pattern). This loop is unrolled at code-gen + # time since all bounds are Python-level constants. 
+ cur_block_len = hw_block_len + while cur_block_len * 4 <= tile_length: + tile.mrgsort(tb_src, tb_tmp, const(cur_block_len)) + tile.mov(tb_tmp, tb_src) + cur_block_len *= 4 + tile.mov(tb_src, tb_dst) + + sv1 = pto.slice_view( + subtensor_type, + source=tv1, + offsets=[tile_idx, c0], + sizes=[c1, c_tile], + ) + pto.store(tb_dst, sv1) + + _kernel.__name__ = fn_name + return to_ir_module(meta_data=_meta_data)(_kernel) + + +if __name__ == "__main__": + print(build_mrgsort_kernel(dtype="float32", tile_length=1024, block_len=64)) diff --git a/tests/npu/mrgsort_dynamic_multicore/caller.py b/tests/npu/mrgsort_dynamic_multicore/caller.py new file mode 100644 index 00000000..eadf4960 --- /dev/null +++ b/tests/npu/mrgsort_dynamic_multicore/caller.py @@ -0,0 +1,37 @@ +"""Generate caller.cpp for the dynamic multicore merge-sort kernel.""" + +import sys + +_DTYPE_TO_CTYPE = { + "float32": "float", + "float16": "half", +} + +_BLOCK_DIM = 24 + + +def fn_name(dtype): + return f"vec_mrgsort_1d_dynamic_{dtype}" + + +def generate_caller(dtype): + ctype = _DTYPE_TO_CTYPE[dtype] + fn = fn_name(dtype) + return f"""\ +#include "{fn}.cpp" + +extern "C" void call_{fn}( + void *stream, uint8_t *src, uint8_t *out, int32_t N) +{{ + {fn}<<<{_BLOCK_DIM}, nullptr, stream>>>( + ({ctype} *)src, ({ctype} *)out, (int32_t)N); +}} +""" + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python caller.py <dtype>", file=sys.stderr) + sys.exit(1) + dtype = sys.argv[1] + print(generate_caller(dtype)) diff --git a/tests/npu/mrgsort_dynamic_multicore/compile.sh b/tests/npu/mrgsort_dynamic_multicore/compile.sh new file mode 100755 index 00000000..ccb6aacb --- /dev/null +++ b/tests/npu/mrgsort_dynamic_multicore/compile.sh @@ -0,0 +1,34 @@ +#!/bin/bash +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +DTYPE=${1:?Usage: compile.sh <dtype>} + +FN_NAME="vec_mrgsort_1d_dynamic_${DTYPE}" + +TMP=$(mktemp -d) +trap "rm -rf $TMP" EXIT + +python "$SCRIPT_DIR/gen_ir.py" "$DTYPE" >
"$TMP/${FN_NAME}.pto" +ptoas --enable-insert-sync "$TMP/${FN_NAME}.pto" -o "$TMP/${FN_NAME}.cpp" + +python "$SCRIPT_DIR/caller.py" "$DTYPE" > "$TMP/caller.cpp" + +PTO_LIB_PATH=/sources/pto-isa +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + "$TMP/caller.cpp" \ + -o "$SCRIPT_DIR/${FN_NAME}_lib.so" + +echo "Built ${FN_NAME}_lib.so successfully." diff --git a/tests/npu/mrgsort_dynamic_multicore/gen_ir.py b/tests/npu/mrgsort_dynamic_multicore/gen_ir.py new file mode 100644 index 00000000..1326395d --- /dev/null +++ b/tests/npu/mrgsort_dynamic_multicore/gen_ir.py @@ -0,0 +1,32 @@ +"""Print MLIR IR for the dynamic multicore merge-sort kernel. 
+ +Usage: python gen_ir.py <dtype> +""" + +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from builder import build_mrgsort_kernel + +TILE_LENGTH = 1024 +BLOCK_LEN = 64 + + +def fn_name(dtype): + return f"vec_mrgsort_1d_dynamic_{dtype}" + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python gen_ir.py <dtype>", file=sys.stderr) + sys.exit(1) + dtype = sys.argv[1] + module = build_mrgsort_kernel( + fn_name=fn_name(dtype), + dtype=dtype, + tile_length=TILE_LENGTH, + block_len=BLOCK_LEN, + ) + print(module) diff --git a/tests/npu/mrgsort_dynamic_multicore/test_mrgsort.py b/tests/npu/mrgsort_dynamic_multicore/test_mrgsort.py new file mode 100644 index 00000000..3011e830 --- /dev/null +++ b/tests/npu/mrgsort_dynamic_multicore/test_mrgsort.py @@ -0,0 +1,194 @@ +import os +import ctypes +import subprocess + +import pytest +import torch +from ptodsl.test_util import get_test_device + +torch.manual_seed(0) + +_DIR = os.path.dirname(os.path.abspath(__file__)) +_DEVICE = get_test_device() + +# TMRGSORT single-list constraints (in terms of hw_block_len = BLOCK_LEN * TYPE_COEF): +# hw_block_len % 64 == 0 +# tile_length % (hw_block_len * 4) == 0 +# 1 <= tile_length // (hw_block_len * 4) <= 255 +# TYPE_COEF = sizeof(float) / sizeof(T): 1 for float32, 2 for float16. +TILE_LENGTH = 1024 +BLOCK_LEN = 64 +TYPE_COEFS = {"float32": 1, "float16": 2} + +# TMRGSORT operates on (float32, uint32) interleaved pairs for float16 tiles, +# not on plain float16 values. Sorting plain float16 with TMRGSORT requires +# a TSORT32 pre-pass (to produce the pair format) and a TGATHER post-pass +# (to extract values). The tests here cover the plain-value sort path only, +# which is supported for float32.
+DTYPES = ["float32"] +SIZES = [1024, 2048, 3072, 4096, 8192, 16384] + +TORCH_DTYPES = { + "float32": torch.float32, + "float16": torch.float16, +} + +_DTYPE_PARAMS = [pytest.param(dtype, id=dtype) for dtype in DTYPES] +_SIZE_PARAMS = [pytest.param(N, id=f"N{N}") for N in SIZES] + + +def _fn_name(dtype: str) -> str: + return f"vec_mrgsort_1d_dynamic_{dtype}" + + +def _lib_path(dtype: str) -> str: + return os.path.join(_DIR, f"{_fn_name(dtype)}_lib.so") + + +def _ctypes_ptr(tensor: torch.Tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +@pytest.fixture(scope="session", params=_DTYPE_PARAMS) +def compiled_lib(request): + dtype = request.param + subprocess.check_call( + ["bash", os.path.join(_DIR, "compile.sh"), dtype], + cwd=_DIR, + ) + yield {"dtype": dtype} + libp = _lib_path(dtype) + if os.path.exists(libp): + os.remove(libp) + + +def _check_preconditions(N: int, block_len: int, tile_length: int): + assert block_len % 64 == 0, f"block_len must be multiple of 64, got {block_len}" + assert tile_length % (block_len * 4) == 0, ( + f"tile_length must be multiple of block_len*4, got " + f"tile_length={tile_length}, block_len={block_len}" + ) + repeat_times = tile_length // (block_len * 4) + assert ( + 1 <= repeat_times <= 255 + ), f"repeat_times must be in [1, 255], got {repeat_times}" + assert N % tile_length == 0, f"N must be a multiple of tile_length, got N={N}" + + +def _make_sorted_sublists( + N: int, + block_len: int, + device, + torch_dtype: torch.dtype, +) -> torch.Tensor: + """ + Create N values split into sorted descending sublists of length block_len. 
+ """ + assert N % block_len == 0 + data = torch.rand(N, dtype=torch.float32) + data = data.view(-1, block_len) + data = torch.sort(data, dim=1, descending=True).values + return data.reshape(-1).to(dtype=torch_dtype, device=device) + + +def _load_fn(dtype: str): + lib = ctypes.CDLL(_lib_path(dtype)) + fn = getattr(lib, f"call_{_fn_name(dtype)}") + fn.argtypes = [ + ctypes.c_void_p, # stream + ctypes.c_void_p, # src + ctypes.c_void_p, # out + ctypes.c_int32, # N + ] + fn.restype = None + return fn + + +def _run_kernel(fn, stream_ptr, src: torch.Tensor, N: int) -> torch.Tensor: + import torch_npu + + out = torch.empty_like(src) + torch.npu.synchronize() + fn(stream_ptr, _ctypes_ptr(src), _ctypes_ptr(out), ctypes.c_int32(N)) + torch.npu.synchronize() + return out + + +def _sort_tiles(x: torch.Tensor, tile_length: int) -> torch.Tensor: + """ + Sort each tile independently descending. + The multi-pass kernel fully sorts each tile (float32) or sorts within + hw_block_len*4 sub-segments (float16); either way, sorted(out_tile) must + equal sorted(src_tile) for a correct permutation sort. + """ + x = x.cpu().float().reshape(-1, tile_length) + x = torch.sort(x, dim=1, descending=True).values + return x.reshape(-1) + + +def test_build_mrgsort(compiled_lib): + dtype = compiled_lib["dtype"] + assert os.path.exists(_lib_path(dtype)) + + +@pytest.mark.require_npu +@pytest.mark.parametrize("N", _SIZE_PARAMS) +def test_mrgsort_equal_after_canonicalization(compiled_lib, N): + """ + Compare exact equality after canonicalizing each hw_block_len*4 segment. + This is the right equality check for the current single-list TMRGSORT behavior. 
+ """ + import torch_npu + + dtype = compiled_lib["dtype"] + hw_block_len = BLOCK_LEN * TYPE_COEFS[dtype] + _check_preconditions(N, hw_block_len, TILE_LENGTH) + torch.npu.set_device(_DEVICE) + + torch_dtype = TORCH_DTYPES[dtype] + fn = _load_fn(dtype) + stream_ptr = torch.npu.current_stream()._as_parameter_ + + src = _make_sorted_sublists(N, hw_block_len, _DEVICE, torch_dtype) + out = _run_kernel(fn, stream_ptr, src, N) + + ref = _sort_tiles(src, TILE_LENGTH).to(torch_dtype) + got = _sort_tiles(out, TILE_LENGTH).to(torch_dtype) + + torch.testing.assert_close( + got.cpu(), + ref.cpu(), + msg="sorted output does not match sorted reference", + ) + + +@pytest.mark.require_npu +@pytest.mark.parametrize("N", [1024]) +def test_mrgsort_deterministic(compiled_lib, N): + """ + Same input must produce identical output on two runs. + """ + import torch_npu + + dtype = compiled_lib["dtype"] + hw_block_len = BLOCK_LEN * TYPE_COEFS[dtype] + _check_preconditions(N, hw_block_len, TILE_LENGTH) + torch.npu.set_device(_DEVICE) + + torch_dtype = TORCH_DTYPES[dtype] + fn = _load_fn(dtype) + stream_ptr = torch.npu.current_stream()._as_parameter_ + + src = _make_sorted_sublists(N, hw_block_len, _DEVICE, torch_dtype) + out1 = _run_kernel(fn, stream_ptr, src, N) + out2 = _run_kernel(fn, stream_ptr, src, N) + + torch.testing.assert_close( + out1.cpu(), + out2.cpu(), + msg="kernel output is not deterministic", + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/npu/reduce_dynamic_multicore/caller.py b/tests/npu/reduce_dynamic_multicore/caller.py new file mode 100644 index 00000000..e22aa443 --- /dev/null +++ b/tests/npu/reduce_dynamic_multicore/caller.py @@ -0,0 +1,49 @@ +"""Generate caller.cpp for dynamic multicore row/col reduction kernels. 
+ +Usage: + python caller.py --mode rowsum|rowmin|rowmax|rowprod|colsum|colmin|colmax|colprod +""" + + +def generate_caller(mode, dtype): + ctype = "half" if dtype == "fp16" else "float" + return f"""\ +#include "{mode}.cpp" + +extern "C" void call_{mode}( + uint32_t blockDim, + void *stream, + uint8_t *x, + uint8_t *y, + uint32_t batch, + uint32_t n_cols) +{{ + _kernel<<<blockDim, nullptr, stream>>>( + reinterpret_cast<{ctype} *>(x), + reinterpret_cast<{ctype} *>(y), + static_cast<int32_t>(batch), + static_cast<int32_t>(n_cols)); +}} +""" + + +if __name__ == "__main__": + import argparse + + MODES = [ + "rowsum", + "rowmin", + "rowmax", + # "rowprod", + "colsum", + "colmin", + "colmax", + # "colprod", + ] + + parser = argparse.ArgumentParser() + parser.add_argument("--mode", choices=MODES, required=True) + parser.add_argument("--dtype", choices=["fp16", "fp32"], default="fp32") + args = parser.parse_args() + + print(generate_caller(args.mode, args.dtype)) diff --git a/tests/npu/reduce_dynamic_multicore/compile.sh b/tests/npu/reduce_dynamic_multicore/compile.sh new file mode 100755 index 00000000..eae6231b --- /dev/null +++ b/tests/npu/reduce_dynamic_multicore/compile.sh @@ -0,0 +1,45 @@ +#!/bin/bash +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +TMP=$(mktemp -d) +trap "rm -rf \"$TMP\"" EXIT + +BISHENG_FLAGS=( + -I${ASCEND_TOOLKIT_HOME}/include + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong + -xcce -Xhost-start -Xhost-end + -mllvm -cce-aicore-stack-size=0x8000 + -mllvm -cce-aicore-function-stack-size=0x8000 + -mllvm -cce-aicore-record-overflow=true + -mllvm -cce-aicore-addr-transform + -mllvm -cce-aicore-dcci-insert-for-scalar=false + --npu-arch=dav-2201 -DMEMORY_BASE + -std=gnu++17 +) + +MODES=( + rowsum + rowmin + rowmax + # rowprod + colsum + colmin + colmax + # colprod +) + +for MODE in "${MODES[@]}"; do + python "$SCRIPT_DIR/gen_ir.py" --mode "$MODE" > "$TMP/${MODE}.pto" + ptoas --enable-insert-sync 
"$TMP/${MODE}.pto" -o "$TMP/${MODE}.cpp" + + python "$SCRIPT_DIR/caller.py" --mode "$MODE" > "$TMP/${MODE}_caller.cpp" + + bisheng "${BISHENG_FLAGS[@]}" \ + "$TMP/${MODE}_caller.cpp" \ + -o "$SCRIPT_DIR/${MODE}_lib.so" + + echo "Built ${MODE}_lib.so successfully." +done diff --git a/tests/npu/reduce_dynamic_multicore/gen_ir.py b/tests/npu/reduce_dynamic_multicore/gen_ir.py new file mode 100644 index 00000000..d5c4ffe0 --- /dev/null +++ b/tests/npu/reduce_dynamic_multicore/gen_ir.py @@ -0,0 +1,45 @@ +"""Print MLIR IR for dynamic multicore row/col reduction kernels. + +Usage: + python gen_ir.py --mode rowsum|rowmin|rowmax|rowprod|colsum|colmin|colmax|colprod +""" + +import argparse +import os +import sys + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from reduce_builder import ( + build_colmax, + build_colmin, + # build_colprod, + build_colsum, + build_rowmax, + build_rowmin, + # build_rowprod, + build_rowsum, +) + +_BUILDERS = { + "rowsum": build_rowsum, + "rowmin": build_rowmin, + "rowmax": build_rowmax, + # "rowprod": build_rowprod, + "colsum": build_colsum, + "colmin": build_colmin, + "colmax": build_colmax, + # "colprod": build_colprod, +} + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + choices=list(_BUILDERS.keys()), + default="rowsum", + ) + parser.add_argument("--dtype", choices=["fp16", "fp32"], default="fp32") + args = parser.parse_args() + + print(_BUILDERS[args.mode](dtype=args.dtype)) diff --git a/tests/npu/reduce_dynamic_multicore/reduce_builder.py b/tests/npu/reduce_dynamic_multicore/reduce_builder.py new file mode 100644 index 00000000..17522ee8 --- /dev/null +++ b/tests/npu/reduce_dynamic_multicore/reduce_builder.py @@ -0,0 +1,361 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +# 32 KB of UB +_TILE_SIZE_BYTES = 32 * 1024 +_DTYPE_BYTES = {"fp16": 2, "fp32": 4} + + +def meta_data_row(dtype="fp32"): + pto_dtype = {"fp16": 
pto.float16, "fp32": pto.float32}[dtype] + elements_per_tile = _TILE_SIZE_BYTES // _DTYPE_BYTES[dtype] + ptr_type = pto.PtrType(pto_dtype) + index_dtype = pto.int32 + + tensor_type = pto.TensorType(rank=1, dtype=pto_dtype) + subtensor_in = pto.SubTensorType(shape=[1, elements_per_tile], dtype=pto_dtype) + + tile_cfg = pto.TileBufConfig() + tile_type = pto.TileBufType( + shape=[1, elements_per_tile], + valid_shape=[1, -1], + dtype=pto_dtype, + memory_space="VEC", + config=tile_cfg, + ) + + return { + "ptr_type": ptr_type, + "pto_dtype": pto_dtype, + "elements_per_tile": elements_per_tile, + "index_dtype": index_dtype, + "tensor_type": tensor_type, + "subtensor_in": subtensor_in, + "tile_type": tile_type, + } + + +def meta_data_col(dtype="fp32"): + pto_dtype = {"fp16": pto.float16, "fp32": pto.float32}[dtype] + ptr_type = pto.PtrType(pto_dtype) + index_dtype = pto.int32 + + tile_rows = 32 + tile_cols = 32 + + tensor2d_type = pto.TensorType(rank=2, dtype=pto_dtype) + subtensor_in = pto.SubTensorType(shape=[tile_rows, tile_cols], dtype=pto_dtype) + subtensor_out = pto.SubTensorType(shape=[1, tile_cols], dtype=pto_dtype) + + tile_cfg = pto.TileBufConfig() + + tile_type = pto.TileBufType( + shape=[tile_rows, tile_cols], + valid_shape=[-1, -1], + dtype=pto_dtype, + memory_space="VEC", + config=tile_cfg, + ) + + tile_out_type = pto.TileBufType( + shape=[1, tile_cols], + valid_shape=[1, -1], + dtype=pto_dtype, + memory_space="VEC", + config=tile_cfg, + ) + + return { + "ptr_type": ptr_type, + "pto_dtype": pto_dtype, + "index_dtype": index_dtype, + "tensor2d_type": tensor2d_type, + "subtensor_in": subtensor_in, + "subtensor_out": subtensor_out, + "tile_type": tile_type, + "tile_out_type": tile_out_type, + "tile_rows": tile_rows, + "tile_cols": tile_cols, + } + + +_ROW_REDUCE_OPS = { + "sum": tile.row_sum, + "min": tile.row_min, + "max": tile.row_max, + "prod": tile.row_prod, +} + +_COL_REDUCE_OPS = { + "sum": tile.col_sum, + "min": lambda src, tmp, dst: tile.col_min(src, 
dst), + "max": lambda src, tmp, dst: tile.col_max(src, dst), + "prod": tile.col_prod, +} + +_COL_COMBINE_OPS = { + "sum": tile.add, + "min": tile.min, + "max": tile.max, + "prod": tile.mul, +} + + +def build_row_reduce(kind="sum", dtype="fp32"): + """ + Generic row-wise reduction across columns. + + Semantics: + y[row] = reduce_j x[row, j] + """ + if kind not in _ROW_REDUCE_OPS: + raise ValueError(f"Unsupported row reduction kind: {kind}") + + row_reduce = _ROW_REDUCE_OPS[kind] + _meta_data = lambda: meta_data_row(dtype=dtype) + + @to_ir_module(meta_data=_meta_data) + def _kernel( + x_ptr: "ptr_type", + y_ptr: "ptr_type", + batch_i32: "index_dtype", + n_cols_i32: "index_dtype", + ) -> None: + c0 = const(0) + c1 = const(1) + + batch = s.index_cast(batch_i32) + n_cols = s.index_cast(n_cols_i32) + + with pto.vector_section(): + bid = s.index_cast(pto.get_block_idx()) + num_cores = s.index_cast(pto.get_block_num()) + + rows_per_core = s.ceil_div(batch, num_cores) + row_start = bid * rows_per_core + row_end = s.min_u(row_start + rows_per_core, batch) + num_rows = row_end - row_start + + total_elems = batch * n_cols + tv_x = pto.as_tensor( + tensor_type, ptr=x_ptr, shape=[total_elems], strides=[c1] + ) + tv_y = pto.as_tensor(tensor_type, ptr=y_ptr, shape=[batch], strides=[c1]) + + with pto.if_context(num_rows > c0): + tb_x = pto.alloc_tile(tile_type, valid_col=n_cols) + tb_out = pto.alloc_tile(tile_type, valid_col=c1) + tb_tmp = pto.alloc_tile(tile_type, valid_col=n_cols) + + for r in pto.range(c0, num_rows, c1): + gm_offset = (row_start + r) * n_cols + + sv_x = pto.slice_view( + subtensor_in, + source=tv_x, + offsets=[gm_offset], + sizes=[n_cols], + ) + + sv_y = pto.slice_view( + subtensor_in, + source=tv_y, + offsets=[row_start + r], + sizes=[c1], + ) + + pto.load(sv_x, tb_x) + row_reduce(tb_x, tb_tmp, tb_out) + pto.store(tb_out, sv_y) + + return _kernel + + +def build_col_reduce(kind="sum", dtype="fp32"): + """ + Generic column-wise reduction across rows. 
+ + Semantics: + y[col] = reduce_i x[i, col] + """ + if kind not in _COL_REDUCE_OPS: + raise ValueError(f"Unsupported column reduction kind: {kind}") + + col_reduce = _COL_REDUCE_OPS[kind] + combine = _COL_COMBINE_OPS[kind] + _meta_data = lambda: meta_data_col(dtype=dtype) + + @to_ir_module(meta_data=_meta_data) + def _kernel( + x_ptr: "ptr_type", + y_ptr: "ptr_type", + batch_i32: "index_dtype", + n_cols_i32: "index_dtype", + ) -> None: + c0 = const(0) + c1 = const(1) + c_tile_rows = const(tile_rows) + c_tile_cols = const(tile_cols) + + batch = s.index_cast(batch_i32) + n_cols = s.index_cast(n_cols_i32) + + with pto.vector_section(): + bid = s.index_cast(pto.get_block_idx()) + num_cores = s.index_cast(pto.get_block_num()) + + cols_per_core = s.ceil_div(n_cols, num_cores) + col_start = bid * cols_per_core + col_end = s.min_u(col_start + cols_per_core, n_cols) + num_cols = col_end - col_start + + tv_x = pto.as_tensor( + tensor2d_type, + ptr=x_ptr, + shape=[batch, n_cols], + strides=[n_cols, c1], + ) + tv_y = pto.as_tensor( + tensor2d_type, + ptr=y_ptr, + shape=[c1, n_cols], + strides=[n_cols, c1], + ) + for col in pto.range(col_start, col_end, c_tile_cols): + cols_this = s.min_u(c_tile_cols, col_end - col) + rows_this0 = s.min_u(c_tile_rows, batch) + + tb_x0 = pto.alloc_tile( + tile_type, + valid_row=rows_this0, + valid_col=cols_this, + ) + tb_tmp0 = pto.alloc_tile( + tile_type, + valid_row=rows_this0, + valid_col=cols_this, + ) + if kind in {"min", "max"}: + tb_acc = pto.alloc_tile( + tile_type, valid_row=c1, valid_col=cols_this + ) + else: + tb_acc = pto.alloc_tile(tile_out_type, valid_col=cols_this) + + sv_x0 = pto.slice_view( + subtensor_in, + source=tv_x, + offsets=[c0, col], + sizes=[rows_this0, cols_this], + ) + pto.load(sv_x0, tb_x0) + col_reduce(tb_x0, tb_tmp0, tb_acc) + + for row in pto.range(c_tile_rows, batch, c_tile_rows): + rows_this = s.min_u(c_tile_rows, batch - row) + + tb_x = pto.alloc_tile( + tile_type, + valid_row=rows_this, + valid_col=cols_this, 
+ ) + tb_tmp = pto.alloc_tile( + tile_type, + valid_row=rows_this, + valid_col=cols_this, + ) + if kind in {"min", "max"}: + tb_part = pto.alloc_tile( + tile_type, valid_row=c1, valid_col=cols_this + ) + else: + tb_part = pto.alloc_tile(tile_out_type, valid_col=cols_this) + + sv_x = pto.slice_view( + subtensor_in, + source=tv_x, + offsets=[row, col], + sizes=[rows_this, cols_this], + ) + pto.load(sv_x, tb_x) + col_reduce(tb_x, tb_tmp, tb_part) + combine(tb_acc, tb_part, tb_acc) + + sv_y = pto.slice_view( + subtensor_out, + source=tv_y, + offsets=[c0, col], + sizes=[c1, cols_this], + ) + pto.store(tb_acc, sv_y) + + return _kernel + + +def build_rowsum(dtype="fp32"): + return build_row_reduce("sum", dtype=dtype) + + +def build_rowmin(dtype="fp32"): + return build_row_reduce("min", dtype=dtype) + + +def build_rowmax(dtype="fp32"): + return build_row_reduce("max", dtype=dtype) + + +def build_rowprod(dtype="fp32"): + return build_row_reduce("prod", dtype=dtype) + + +def build_colsum(dtype="fp32"): + return build_col_reduce("sum", dtype=dtype) + + +def build_colmin(dtype="fp32"): + return build_col_reduce("min", dtype=dtype) + + +def build_colmax(dtype="fp32"): + return build_col_reduce("max", dtype=dtype) + + +def build_colprod(dtype="fp32"): + return build_col_reduce("prod", dtype=dtype) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument( + "--mode", + choices=[ + "rowsum", + "rowmin", + "rowmax", + "rowprod", + "colsum", + "colmin", + "colmax", + "colprod", + ], + default="rowsum", + ) + parser.add_argument("--dtype", choices=["fp16", "fp32"], default="fp32") + args = parser.parse_args() + + builders = { + "rowsum": build_rowsum, + "rowmin": build_rowmin, + "rowmax": build_rowmax, + "rowprod": build_rowprod, + "colsum": build_colsum, + "colmin": build_colmin, + "colmax": build_colmax, + "colprod": build_colprod, + } + + print(builders[args.mode](dtype=args.dtype)) diff --git 
a/tests/npu/reduce_dynamic_multicore/test_reduce.py b/tests/npu/reduce_dynamic_multicore/test_reduce.py new file mode 100644 index 00000000..c420f06e --- /dev/null +++ b/tests/npu/reduce_dynamic_multicore/test_reduce.py @@ -0,0 +1,156 @@ +import ctypes +import os +import subprocess + +import pytest +import torch + +from ptodsl.test_util import get_test_device + +torch.manual_seed(0) + +_DIR = os.path.dirname(os.path.abspath(__file__)) +_DEVICE = get_test_device() +_BLOCK_DIM = 24 + +_KERNELS = [ + "rowsum", + "rowmin", + "rowmax", + # "rowprod", + "colsum", + "colmin", + "colmax", + # "colprod", +] + +_LIB_PATHS = {name: os.path.join(_DIR, f"{name}_lib.so") for name in _KERNELS} + +_SHAPES = [ + (1, 1), + (2, 3), + (3, 2), + (7, 7), + (15, 17), + (17, 15), + (31, 32), + (32, 31), + (32, 32), + (33, 33), + (31, 33), + (33, 31), + (64, 32), + (32, 64), + (63, 64), + (64, 63), + (65, 33), + (65, 64), + (29, 257), + (127, 129), +] + +_SHAPE_IDS = [f"batch{batch}-cols{n_cols}" for batch, n_cols in _SHAPES] + + +@pytest.fixture(scope="session") +def compiled_kernels(): + subprocess.check_call(["bash", os.path.join(_DIR, "compile.sh")], cwd=_DIR) + yield + for path in _LIB_PATHS.values(): + if os.path.exists(path): + os.remove(path) + + +def _load_kernel(name): + lib = ctypes.CDLL(_LIB_PATHS[name]) + fn = getattr(lib, f"call_{name}") + fn.argtypes = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_uint32, + ctypes.c_uint32, + ] + fn.restype = None + return fn + + +def _reference(name, x): + if name == "rowsum": + return x.float().sum(dim=-1) + if name == "rowmin": + return x.float().amin(dim=-1) + if name == "rowmax": + return x.float().amax(dim=-1) + if name == "rowprod": + return x.float().prod(dim=-1) + if name == "colsum": + return x.float().sum(dim=0) + if name == "colmin": + return x.float().amin(dim=0) + if name == "colmax": + return x.float().amax(dim=0) + if name == "colprod": + return x.float().prod(dim=0) + raise 
ValueError(f"Unknown kernel: {name}") + + +def _output_shape(name, batch, n_cols): + return (batch,) if name.startswith("row") else (n_cols,) + + +def _make_input(name, batch, n_cols, device): + if name.endswith("prod"): + return torch.empty(batch, n_cols, device=device, dtype=torch.float32).uniform_( + 0.5, 1.5 + ) + return torch.randn(batch, n_cols, device=device, dtype=torch.float32) + + +def _tolerances(name): + if name.endswith("prod"): + return {"atol": 1e-3, "rtol": 1e-3} + return {"atol": 1e-4, "rtol": 0} + + +@pytest.mark.parametrize("name", _KERNELS) +def test_build_kernel(compiled_kernels, name): + assert os.path.exists(_LIB_PATHS[name]) + + +@pytest.mark.require_npu +@pytest.mark.parametrize("name", _KERNELS) +@pytest.mark.parametrize("batch, n_cols", _SHAPES, ids=_SHAPE_IDS) +def test_kernel_precision(compiled_kernels, name, batch, n_cols): + import torch_npu # noqa: F401 + + torch.npu.set_device(_DEVICE) + + fn = _load_kernel(name) + + x = _make_input(name, batch, n_cols, _DEVICE) + y = torch.full( + _output_shape(name, batch, n_cols), + float("nan"), + device=_DEVICE, + dtype=torch.float32, + ) + y_ref = _reference(name, x) + + stream_ptr = torch.npu.current_stream()._as_parameter_ + fn( + ctypes.c_uint32(_BLOCK_DIM), + stream_ptr, + ctypes.c_void_p(x.data_ptr()), + ctypes.c_void_p(y.data_ptr()), + ctypes.c_uint32(batch), + ctypes.c_uint32(n_cols), + ) + torch.npu.synchronize() + + torch.testing.assert_close(y, y_ref, **_tolerances(name)) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/npu/sort32_dynamic_multicore/builder.py b/tests/npu/sort32_dynamic_multicore/builder.py new file mode 100644 index 00000000..a6c419eb --- /dev/null +++ b/tests/npu/sort32_dynamic_multicore/builder.py @@ -0,0 +1,194 @@ +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +# TSORT32 sorts within fixed 32-element blocks. 
+# Each input element expands into (score, index) pairs in the output: +# float16: 4 float16 words [score_f16, zero, idx_lo_u16, idx_hi_u16] +# float32: 2 float32 words [score_f32, idx_u32] +_SORT_BLOCK_LEN = 32 + +_DTYPES = { + "float16": lambda: pto.float16, + "float32": lambda: pto.float32, +} + +# Output words per input element (in units of the src dtype) +_DST_STRIDE = { + "float16": 4, + "float32": 2, +} + + +def meta_data(dtype="float16", tile_length=1024): + if isinstance(dtype, str): + dtype_str = dtype + pto_dtype = _DTYPES[dtype]() + else: + pto_dtype = dtype + dtype_str = "float16" + + dst_stride = _DST_STRIDE[dtype_str] + u32 = pto.uint32 + dst_tile_length = tile_length * dst_stride + + tile_cfg = pto.TileBufConfig() + return { + "ptr_src": pto.PtrType(pto_dtype), + "ptr_u32": pto.PtrType(u32), + "ptr_dst": pto.PtrType(pto_dtype), + "index_dtype": pto.int32, + "tensor_src": pto.TensorType(rank=2, dtype=pto_dtype), + "tensor_u32": pto.TensorType(rank=2, dtype=u32), + "tensor_dst": pto.TensorType(rank=2, dtype=pto_dtype), + "subtensor_src": pto.SubTensorType(shape=[1, tile_length], dtype=pto_dtype), + "subtensor_u32": pto.SubTensorType(shape=[1, tile_length], dtype=u32), + "subtensor_dst": pto.SubTensorType(shape=[1, dst_tile_length], dtype=pto_dtype), + "tile_src": pto.TileBufType( + shape=[1, tile_length], + valid_shape=[1, tile_length], + dtype=pto_dtype, + memory_space="VEC", + config=tile_cfg, + ), + "tile_u32": pto.TileBufType( + shape=[1, tile_length], + valid_shape=[1, tile_length], + dtype=u32, + memory_space="VEC", + config=tile_cfg, + ), + "tile_dst": pto.TileBufType( + shape=[1, dst_tile_length], + valid_shape=[1, dst_tile_length], + dtype=pto_dtype, + memory_space="VEC", + config=tile_cfg, + ), + "tile_length": tile_length, + "dst_tile_length": dst_tile_length, + } + + +def build_tsort32_kernel( + fn_name="tsort32_1d_dynamic_float16", + dtype="float16", + tile_length=1024, +): + """Build a 1D dynamic multicore TSORT32 kernel. 
+ + For each tile of tile_length elements: + - Reads src (scores) and idx (uint32 indices). + - Calls TSORT32, which sorts within _SORT_BLOCK_LEN-element blocks and + writes interleaved (score, index) pairs to dst. + - dst is dst_stride times wider than src in same-dtype words: + float16: dst_stride=4 → dst is float16[N * 4] + float32: dst_stride=2 → dst is float32[N * 2] + + Constraints: + - tile_length must be a multiple of _SORT_BLOCK_LEN (32) + - N (total input elements) must be a multiple of tile_length + """ + assert ( + tile_length % _SORT_BLOCK_LEN == 0 + ), f"tile_length must be a multiple of {_SORT_BLOCK_LEN}, got {tile_length}" + dtype_str = dtype if isinstance(dtype, str) else "float16" + dst_stride = _DST_STRIDE[dtype_str] + dst_tile_length = tile_length * dst_stride + _meta_data = lambda: meta_data(dtype=dtype, tile_length=tile_length) + + def _kernel( + arg_src: "ptr_src", # input scores [N] + arg_idx: "ptr_u32", # uint32 input indices [N] + arg_dst: "ptr_dst", # output pairs [N * dst_stride] + argN: "index_dtype", # total input elements (multiple of tile_length) + ) -> None: + c0 = const(0) + c1 = const(1) + c_tile = const(tile_length) + c_dst_tile = const(dst_tile_length) + + total_elements = s.index_cast(argN) + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + vid = cid * sub_bnum + sub_bid + num_blocks = pto.get_block_num() + + vid_idx = s.index_cast(vid) + num_cores = s.index_cast(num_blocks * sub_bnum) + + num_tiles_global = s.ceil_div(total_elements, c_tile) + num_tiles_per_core = s.ceil_div(num_tiles_global, num_cores) + tile_offset_this_core = vid_idx * num_tiles_per_core + + with pto.vector_section(): + tv_src = pto.as_tensor( + tensor_src, + ptr=arg_src, + shape=[num_tiles_global, c_tile], + strides=[c_tile, c1], + ) + tv_idx = pto.as_tensor( + tensor_u32, + ptr=arg_idx, + shape=[num_tiles_global, c_tile], + strides=[c_tile, c1], + ) + tv_dst = pto.as_tensor( + tensor_dst, + ptr=arg_dst, + 
shape=[num_tiles_global, c_dst_tile], + strides=[c_dst_tile, c1], + ) + + tb_src = pto.alloc_tile(tile_src) + tb_idx = pto.alloc_tile(tile_u32) + tb_dst = pto.alloc_tile(tile_dst) + + with pto.if_context(tile_offset_this_core < num_tiles_global): + tiles_end_this_core = tile_offset_this_core + num_tiles_per_core + need_truncate = tiles_end_this_core > num_tiles_global + remaining_tiles = num_tiles_global - tile_offset_this_core + tiles_to_process = s.select( + need_truncate, remaining_tiles, num_tiles_per_core + ) + + with pto.if_context(tiles_to_process > c0): + for i in pto.range(c0, tiles_to_process, c1): + ti = i + tile_offset_this_core + + sv_src = pto.slice_view( + subtensor_src, + source=tv_src, + offsets=[ti, c0], + sizes=[c1, c_tile], + ) + sv_idx = pto.slice_view( + subtensor_u32, + source=tv_idx, + offsets=[ti, c0], + sizes=[c1, c_tile], + ) + sv_dst = pto.slice_view( + subtensor_dst, + source=tv_dst, + offsets=[ti, c0], + sizes=[c1, c_dst_tile], + ) + + pto.load(sv_src, tb_src) + pto.load(sv_idx, tb_idx) + tile.sort32(tb_src, tb_dst, tb_idx) + pto.store(tb_dst, sv_dst) + + _kernel.__name__ = fn_name + return to_ir_module(meta_data=_meta_data)(_kernel) + + +if __name__ == "__main__": + import sys + + dtype = sys.argv[1] if len(sys.argv) > 1 else "float16" + print(build_tsort32_kernel(dtype=dtype)) diff --git a/tests/npu/sort32_dynamic_multicore/caller.py b/tests/npu/sort32_dynamic_multicore/caller.py new file mode 100644 index 00000000..1a092dc5 --- /dev/null +++ b/tests/npu/sort32_dynamic_multicore/caller.py @@ -0,0 +1,36 @@ +"""Generate caller.cpp for the dynamic multicore TSORT32 kernel.""" + +import sys + +_DTYPE_TO_CTYPE = { + "float16": "half", + "float32": "float", +} + +_BLOCK_DIM = 24 + + +def fn_name(dtype): + return f"tsort32_1d_dynamic_{dtype}" + + +def generate_caller(dtype): + ctype = _DTYPE_TO_CTYPE[dtype] + fn = fn_name(dtype) + return f"""\ +#include "{fn}.cpp" + +extern "C" void call_{fn}( + void *stream, uint8_t *src, uint8_t *idx, 
uint8_t *dst, int32_t N) +{{ + {fn}<<<{_BLOCK_DIM}, nullptr, stream>>>( + ({ctype} *)src, (uint32_t *)idx, ({ctype} *)dst, (int32_t)N); +}} +""" + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python caller.py <dtype>", file=sys.stderr) + sys.exit(1) + print(generate_caller(sys.argv[1])) diff --git a/tests/npu/sort32_dynamic_multicore/compile.sh b/tests/npu/sort32_dynamic_multicore/compile.sh new file mode 100755 index 00000000..1e21d2c0 --- /dev/null +++ b/tests/npu/sort32_dynamic_multicore/compile.sh @@ -0,0 +1,34 @@ +#!/bin/bash +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" + +DTYPE=${1:?Usage: compile.sh <dtype>} + +FN_NAME="tsort32_1d_dynamic_${DTYPE}" + +TMP=$(mktemp -d) +trap "rm -rf \"$TMP\"" EXIT + +python "$SCRIPT_DIR/gen_ir.py" "$DTYPE" > "$TMP/${FN_NAME}.pto" +ptoas --enable-insert-sync "$TMP/${FN_NAME}.pto" -o "$TMP/${FN_NAME}.cpp" + +python "$SCRIPT_DIR/caller.py" "$DTYPE" > "$TMP/caller.cpp" + +PTO_LIB_PATH=/sources/pto-isa +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + "$TMP/caller.cpp" \ + -o "$SCRIPT_DIR/${FN_NAME}_lib.so" + +echo "Built ${FN_NAME}_lib.so successfully." diff --git a/tests/npu/sort32_dynamic_multicore/gen_ir.py b/tests/npu/sort32_dynamic_multicore/gen_ir.py new file mode 100644 index 00000000..a2632f69 --- /dev/null +++ b/tests/npu/sort32_dynamic_multicore/gen_ir.py @@ -0,0 +1,28 @@ +"""Print MLIR IR for the dynamic multicore TSORT32 kernel. 
+ +Usage: python gen_ir.py <dtype> +""" + +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from builder import build_tsort32_kernel + +TILE_LENGTH = 1024 + + +def fn_name(dtype): + return f"tsort32_1d_dynamic_{dtype}" + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print("Usage: python gen_ir.py <dtype>", file=sys.stderr) + sys.exit(1) + dtype = sys.argv[1] + module = build_tsort32_kernel( + fn_name=fn_name(dtype), dtype=dtype, tile_length=TILE_LENGTH + ) + print(module) diff --git a/tests/npu/sort32_dynamic_multicore/test_tsort32.py b/tests/npu/sort32_dynamic_multicore/test_tsort32.py new file mode 100644 index 00000000..453db5bc --- /dev/null +++ b/tests/npu/sort32_dynamic_multicore/test_tsort32.py @@ -0,0 +1,146 @@ +import os +import ctypes +import subprocess + +import pytest +import torch +from ptodsl.test_util import get_test_device + +torch.manual_seed(0) + +_DIR = os.path.dirname(os.path.abspath(__file__)) +_DEVICE = get_test_device() + +# TSORT32 sorts within fixed 32-element blocks. +# Each input element expands into (score, index) pairs in the output: +# float16: dst_stride=4 → [score_f16, zero, idx_lo_u16, idx_hi_u16] +# float32: dst_stride=2 → [score_f32, idx_u32] +# tile_length must be a multiple of SORT_BLOCK_LEN. 
+TILE_LENGTH = 1024 +SORT_BLOCK_LEN = 32 +DTYPES = ["float16", "float32"] +SIZES = [1024, 2048, 3072, 4096, 6144, 8192, 16384] + +_DST_STRIDE = {"float16": 4, "float32": 2} +_TORCH_DTYPES = {"float16": torch.float16, "float32": torch.float32} + +_DTYPE_PARAMS = [pytest.param(dtype, id=dtype) for dtype in DTYPES] +_SIZE_PARAMS = [pytest.param(N, id=f"N{N}") for N in SIZES] + + +def _fn_name(dtype): + return f"tsort32_1d_dynamic_{dtype}" + + +def _lib_path(dtype): + return os.path.join(_DIR, f"{_fn_name(dtype)}_lib.so") + + +def _ctypes_ptr(tensor: torch.Tensor): + return ctypes.c_void_p(tensor.data_ptr()) + + +@pytest.fixture(scope="session", params=_DTYPE_PARAMS) +def compiled_lib(request): + dtype = request.param + subprocess.check_call( + ["bash", os.path.join(_DIR, "compile.sh"), dtype], + cwd=_DIR, + ) + yield {"dtype": dtype} + libp = _lib_path(dtype) + if os.path.exists(libp): + os.remove(libp) + + +def _load_fn(dtype): + lib = ctypes.CDLL(_lib_path(dtype)) + fn = getattr(lib, f"call_{_fn_name(dtype)}") + fn.argtypes = [ + ctypes.c_void_p, # stream + ctypes.c_void_p, # src + ctypes.c_void_p, # idx (uint32) + ctypes.c_void_p, # dst (N * dst_stride elements) + ctypes.c_int32, # N + ] + fn.restype = None + return fn + + +def _run_kernel( + fn, stream_ptr, src: torch.Tensor, idx: torch.Tensor, N: int, dst_stride: int +) -> torch.Tensor: + import torch_npu + + dst = torch.empty(N * dst_stride, dtype=src.dtype, device=src.device) + torch.npu.synchronize() + fn( + stream_ptr, + _ctypes_ptr(src), + _ctypes_ptr(idx), + _ctypes_ptr(dst), + ctypes.c_int32(N), + ) + torch.npu.synchronize() + return dst + + +def _check_preconditions(N: int): + assert ( + N % TILE_LENGTH == 0 + ), f"N must be a multiple of TILE_LENGTH={TILE_LENGTH}, got {N}" + assert TILE_LENGTH % SORT_BLOCK_LEN == 0 + + +def _extract_scores(dst: torch.Tensor, dst_stride: int) -> torch.Tensor: + """Slot 0 of each output group holds the sorted score.""" + return dst.cpu().reshape(-1, dst_stride)[:, 0] + + 
+def _reference_scores(src: torch.Tensor) -> torch.Tensor: + """Sort each SORT_BLOCK_LEN-element group descending.""" + return ( + src.cpu() + .reshape(-1, SORT_BLOCK_LEN) + .sort(dim=1, descending=True) + .values.reshape(-1) + ) + + +def test_build_tsort32(compiled_lib): + dtype = compiled_lib["dtype"] + assert os.path.exists(_lib_path(dtype)) + + +@pytest.mark.require_npu +@pytest.mark.parametrize("N", _SIZE_PARAMS) +def test_tsort32_scores(compiled_lib, N): + """Scores extracted from TSORT32 output match per-block sorted input.""" + import torch_npu + + dtype = compiled_lib["dtype"] + torch_dtype = _TORCH_DTYPES[dtype] + dst_stride = _DST_STRIDE[dtype] + + _check_preconditions(N) + torch.npu.set_device(_DEVICE) + + fn = _load_fn(dtype) + stream_ptr = torch.npu.current_stream()._as_parameter_ + + src = torch.rand(N, dtype=torch_dtype, device=_DEVICE) + idx = torch.arange(N, dtype=torch.int32, device=_DEVICE) + dst = _run_kernel(fn, stream_ptr, src, idx, N, dst_stride) + + scores_got = _extract_scores(dst, dst_stride) + scores_ref = _reference_scores(src) + + torch.testing.assert_close( + scores_got, + scores_ref, + msg="TSORT32 scores do not match per-block sorted reference", + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/regression/test_a5_lib_regression.py b/tests/regression/test_a5_lib_regression.py new file mode 100644 index 00000000..d9b23c7d --- /dev/null +++ b/tests/regression/test_a5_lib_regression.py @@ -0,0 +1,420 @@ +import pytest +from mlir.ir import IndexType + +from ptodsl import pto, to_ir_module +from ptodsl.lib import a5 +from scripts.generate_a5_pto import emit_kernels + + +def test_a5_elementwise_add_kernel_emits_tile_flow(): + text = str(a5.build_elementwise_add()) + + assert "func.func @a5_elementwise_add" in text + assert "pto.make_tensor_view" in text + assert "pto.tload" in text + assert "pto.vlds" in text + assert "pto.vadd" in text + assert "pto.vsts" in text + assert "pto.tadd" not in text + 
assert "pto.tstore" in text
+
+
+def test_a5_templated_elementwise_add_specializes_constexpr_impl():
+    """Specialize the templated elementwise add and inspect its IR text.
+
+    The object returned by ``a5.build_templated_elementwise_add()`` is called
+    with ``ROWS=8``, ``COLS=64`` and ``VF_IMPL=a5.VF_IMPL_1D_POST_UPDATE``.
+    The stringified result must define ``@a5_templated_elementwise_add``, must
+    not contain the parameter names or ``scf.if``, and must contain
+    ``pto.vlds_post``/``pto.vsts_post`` but not ``pto.tadd``.
+    """
+    specializer = a5.build_templated_elementwise_add()
+    text = str(
+        specializer(
+            ROWS=8,
+            COLS=64,
+            VF_IMPL=a5.VF_IMPL_1D_POST_UPDATE,
+        )
+    )
+
+    assert "func.func @a5_templated_elementwise_add(%arg0" in text
+    # The parameter names passed above must not appear anywhere in the text.
+    assert "ROWS" not in text
+    assert "COLS" not in text
+    assert "VF_IMPL" not in text
+    assert "scf.if" not in text
+    assert "pto.vlds_post" in text
+    assert "pto.vsts_post" in text
+    assert "pto.tadd" not in text
+
+
+def test_a5_micro_vector_copy_emits_micro_ops():
+    """Check the IR text of ``a5.build_micro_vector_copy()``.
+
+    The text must define ``@a5_micro_vector_copy`` and contain
+    ``pto.pset_b32``, ``pto.vlds`` and ``pto.vsts``.
+    """
+    text = str(a5.build_micro_vector_copy())
+
+    assert "func.func @a5_micro_vector_copy" in text
+    assert "pto.pset_b32" in text
+    # Plain substring checks: "pto.vlds" is also satisfied by "pto.vlds_post".
+    assert "pto.vlds" in text
+    assert "pto.vsts" in text
+
+
+def test_a5_col_expand_micro_emits_broadcast_micro_ops():
+    """Build a kernel around ``a5.col_expand_micro`` and inspect its IR text.
+
+    The kernel slices a [1, 32] float32 source view and a [32, 32] float32
+    destination view. The text must contain ``pto.vlds`` and ``pto.vsts`` and
+    must not contain ``pto.tcolexpand``.
+    """
+
+    def meta_data():
+        # NOTE(review): "index_t" is not referenced by any annotation in this
+        # test; confirm whether to_ir_module needs it.
+        return {
+            "ptr_t": pto.ptr(pto.float32),
+            "index_t": IndexType.get(),
+        }
+
+    @to_ir_module(meta_data=meta_data)
+    def a5_col_expand_micro(src: "ptr_t", dst: "ptr_t") -> None:
+        src_view = pto.make_tensor(src, shape=[1, 32], dtype=pto.float32)
+        dst_view = pto.make_tensor(dst, shape=[32, 32], dtype=pto.float32)
+        with pto.vector_section():
+            a5.col_expand_micro(
+                src_view.slice([0, 0], [1, 32]),
+                dst_view.slice([0, 0], [32, 32]),
+                dtype=pto.float32,
+                shape=[32, 32],
+            )
+
+    text = str(a5_col_expand_micro)
+
+    assert "func.func @a5_col_expand_micro" in text
+    assert "pto.vlds" in text
+    assert "pto.vsts" in text
+    assert "pto.tcolexpand" not in text
+
+
+def test_a5_gather_micro_emits_indexed_gather_micro_ops():
+    """Build a kernel around ``a5.gather_micro`` and inspect its IR text.
+
+    Source and destination are [1, 64] float32 views and the index is a
+    [1, 64] uint32 view. The text must contain ``pto.vgather2`` and
+    ``pto.vsts`` and must not contain ``pto.tgather``.
+    """
+
+    def meta_data():
+        return {
+            "ptr_src": pto.ptr(pto.float32),
+            "ptr_idx": pto.ptr(pto.uint32),
+        }
+
+    @to_ir_module(meta_data=meta_data)
+    def a5_gather_micro(src: "ptr_src", idx: "ptr_idx", dst: "ptr_src") -> None:
+        src_view = pto.make_tensor(src, shape=[1, 64], dtype=pto.float32)
+        idx_view = pto.make_tensor(idx, shape=[1, 64], dtype=pto.uint32)
+        dst_view = pto.make_tensor(dst, shape=[1, 64], dtype=pto.float32)
+        with pto.vector_section():
+            a5.gather_micro(
+                src_view.slice([0, 0], [1, 64]),
+                idx_view.slice([0, 0], [1, 64]),
+                dst_view.slice([0, 0], [1, 64]),
+                dtype=pto.float32,
+                index_dtype=pto.uint32,
+                shape=[1, 64],
+            )
+
+    text = str(a5_gather_micro)
+
+    assert "func.func @a5_gather_micro" in text
+    assert "pto.vgather2" in text
+    assert "pto.vsts" in text
+    assert "pto.tgather" not in text
+
+
+def test_a5_row_expand_micro_emits_broadcast_micro_ops():
+    """Build a kernel around ``a5.row_expand_micro`` and inspect its IR text.
+
+    The kernel slices a [32, 1] float32 source view and a [32, 32] float32
+    destination view. The text must contain ``pto.vldas``, ``pto.vldus``,
+    ``pto.vdup`` and ``pto.vsts`` and must not contain ``pto.trowexpand``.
+    """
+
+    def meta_data():
+        return {
+            "ptr_t": pto.ptr(pto.float32),
+            "index_t": IndexType.get(),
+        }
+
+    @to_ir_module(meta_data=meta_data)
+    def a5_row_expand_micro(src: "ptr_t", dst: "ptr_t") -> None:
+        src_view = pto.make_tensor(src, shape=[32, 1], dtype=pto.float32)
+        dst_view = pto.make_tensor(dst, shape=[32, 32], dtype=pto.float32)
+        with pto.vector_section():
+            a5.row_expand_micro(
+                src_view.slice([0, 0], [32, 1]),
+                dst_view.slice([0, 0], [32, 32]),
+                dtype=pto.float32,
+                shape=[32, 32],
+            )
+
+    text = str(a5_row_expand_micro)
+
+    assert "func.func @a5_row_expand_micro" in text
+    assert "pto.vldas" in text
+    assert "pto.vldus" in text
+    assert "pto.vdup" in text
+    assert "pto.vsts" in text
+    assert "pto.trowexpand" not in text
+
+
+def test_a5_row_expand_mul_micro_emits_broadcast_compute_micro_ops():
+    """Build a kernel around ``a5.row_expand_mul_micro`` and inspect its IR text.
+
+    ``base`` and ``dst`` are [32, 32] float32 views and ``scale`` is a [32, 1]
+    float32 view. The text must contain ``pto.vldas``, ``pto.vldus``,
+    ``pto.vdup``, ``pto.vmul`` and ``pto.vsts`` and must not contain
+    ``pto.trowexpandmul``.
+    """
+
+    def meta_data():
+        return {
+            "ptr_t": pto.ptr(pto.float32),
+            "index_t": IndexType.get(),
+        }
+
+    @to_ir_module(meta_data=meta_data)
+    def a5_row_expand_mul_micro(base: "ptr_t", scale: "ptr_t", dst: "ptr_t") -> None:
+        base_view = pto.make_tensor(base, shape=[32, 32], dtype=pto.float32)
+        scale_view = pto.make_tensor(scale, shape=[32, 1], dtype=pto.float32)
+        dst_view = pto.make_tensor(dst, shape=[32, 32], dtype=pto.float32)
+        with pto.vector_section():
+            a5.row_expand_mul_micro(
+                base_view.slice([0, 0], [32, 32]),
+                scale_view.slice([0, 0], [32, 1]),
+                dst_view.slice([0, 0], [32, 32]),
+                dtype=pto.float32,
+                shape=[32, 32],
+            )
+
+    text = str(a5_row_expand_mul_micro)
+
+    assert "func.func @a5_row_expand_mul_micro" in text
+    assert "pto.vldas" in text
+    assert "pto.vldus" in text
+    assert "pto.vdup" in text
+    assert "pto.vmul" in text
+    assert "pto.vsts" in text
+    assert "pto.trowexpandmul" not in text
+
+
+def test_a5_rsqrt_micro_emits_vsqrt_then_vrec():
+    """Build a kernel around ``a5.rsqrt_micro`` and inspect its IR text.
+
+    Source and destination are [1, 64] float32 views. The text must contain
+    ``pto.vsqrt`` and ``pto.vrec`` and must not contain ``pto.trsqrt``. Only
+    presence is checked, not the relative order of the two ops.
+    """
+
+    def meta_data():
+        return {
+            "ptr_t": pto.ptr(pto.float32),
+        }
+
+    @to_ir_module(meta_data=meta_data)
+    def a5_rsqrt_micro(src: "ptr_t", dst: "ptr_t") -> None:
+        src_view = pto.make_tensor(src, shape=[1, 64], dtype=pto.float32)
+        dst_view = pto.make_tensor(dst, shape=[1, 64], dtype=pto.float32)
+        with pto.vector_section():
+            a5.rsqrt_micro(
+                src_view.slice([0, 0], [1, 64]),
+                dst_view.slice([0, 0], [1, 64]),
+                dtype=pto.float32,
+                shape=[1, 64],
+            )
+
+    text = str(a5_rsqrt_micro)
+
+    assert "func.func @a5_rsqrt_micro" in text
+    assert "pto.vsqrt" in text
+    assert "pto.vrec" in text
+    assert "pto.trsqrt" not in text
+
+
+@pytest.mark.parametrize(
+    ("helper_name", "reduce_op", "combine_op", "tile_op"),
+    [
+        ("row_sum_micro", "pto.vcadd", "pto.vadd", "pto.trowsum"),
+        ("row_max_micro", "pto.vcmax", "pto.vmax", "pto.trowmax"),
+        ("row_min_micro", "pto.vcmin", "pto.vmin", "pto.trowmin"),
+    ],
+)
+def test_a5_row_reduce_micro_emits_reduction_micro_ops(
+    helper_name, reduce_op, combine_op, tile_op
+):
+    """Build a kernel around one row-reduce helper and inspect its IR text.
+
+    Args:
+        helper_name: Attribute of ``a5`` to call inside the kernel.
+        reduce_op: Op name that must appear in the IR text.
+        combine_op: Second op name that must appear in the IR text.
+        tile_op: Op name that must not appear in the IR text.
+
+    The kernel slices a [32, 32] float32 source view and a [32, 1] float32
+    destination view. The text must also contain ``dist = "ONEPT_B32"``.
+    """
+
+    def meta_data():
+        return {
+            "ptr_t": pto.ptr(pto.float32),
+            "index_t": IndexType.get(),
+        }
+
+    # Resolve the helper before the kernel is defined; a misspelled
+    # helper_name raises AttributeError here.
+    helper = getattr(a5, helper_name)
+
+    @to_ir_module(meta_data=meta_data)
+    def a5_row_reduce_micro(src: "ptr_t", dst: "ptr_t") -> None:
+        src_view = pto.make_tensor(src, shape=[32, 32], dtype=pto.float32)
+        dst_view = pto.make_tensor(dst, shape=[32, 1], dtype=pto.float32)
+        with pto.vector_section():
+            helper(
+                src_view.slice([0, 0], [32, 32]),
+                dst_view.slice([0, 0], [32, 1]),
+                dtype=pto.float32,
+                shape=[32, 32],
+            )
+
+    text = str(a5_row_reduce_micro)
+
+    assert reduce_op in text
+    assert combine_op in text
+    assert 'dist = "ONEPT_B32"' in text
+    assert tile_op not in text
+
+
+@pytest.mark.parametrize(
+    ("helper_name", "reduce_op", "tile_op", "impl"),
+    [
+        ("col_sum_micro", "pto.vadd", "pto.tcolsum", a5.VF_IMPL_1D_POST_UPDATE),
+        ("col_max_micro", "pto.vmax", "pto.tcolmax", a5.VF_IMPL_1D_NO_POST_UPDATE),
+        ("col_min_micro", "pto.vmin", "pto.tcolmin", a5.VF_IMPL_1D_POST_UPDATE),
+    ],
+)
+def test_a5_col_reduce_micro_emits_template_lowering(
+    helper_name, reduce_op, tile_op, impl
+):
+    """Build a kernel around one column-reduce helper and inspect its IR text.
+
+    Args:
+        helper_name: Attribute of ``a5`` to call inside the kernel.
+        reduce_op: Op name that must appear in the IR text.
+        tile_op: Op name that must not appear in the IR text.
+        impl: Value forwarded to the helper as ``impl``.
+
+    The kernel slices a [32, 32] float32 source view and a [1, 32] float32
+    destination view.
+    """
+
+    def meta_data():
+        return {
+            "ptr_t": pto.ptr(pto.float32),
+        }
+
+    helper = getattr(a5, helper_name)
+
+    @to_ir_module(meta_data=meta_data)
+    def a5_col_reduce_micro(src: "ptr_t", dst: "ptr_t") -> None:
+        src_view = pto.make_tensor(src, shape=[32, 32], dtype=pto.float32)
+        dst_view = pto.make_tensor(dst, shape=[1, 32], dtype=pto.float32)
+        with pto.vector_section():
+            helper(
+                src_view.slice([0, 0], [32, 32]),
+                dst_view.slice([0, 0], [1, 32]),
+                dtype=pto.float32,
+                shape=[32, 32],
+                impl=impl,
+            )
+
+    text = str(a5_col_reduce_micro)
+
+    assert reduce_op in text
+    assert tile_op not in text
+    # Only the post-update case checks load/store op names; the
+    # VF_IMPL_1D_NO_POST_UPDATE case has no extra assertion.
+    if impl == a5.VF_IMPL_1D_POST_UPDATE:
+        assert "pto.vlds_post" in text
+        assert "pto.vsts_post" in text
+
+
+def test_a5_sort32_micro_emits_vbitsort():
+    """Build a kernel around ``a5.sort32_micro`` and inspect its IR text.
+
+    Source is a [1, 64] float32 view, index a [1, 64] uint32 view and
+    destination a [1, 128] float32 view. The text must contain
+    ``pto.vbitsort`` and must not contain ``pto.tsort32``.
+    """
+
+    def meta_data():
+        return {
+            "ptr_src": pto.ptr(pto.float32),
+            "ptr_idx": pto.ptr(pto.uint32),
+        }
+
+    @to_ir_module(meta_data=meta_data)
+    def a5_sort32_micro(src: "ptr_src", idx: "ptr_idx", dst: "ptr_src") -> None:
+        src_view = pto.make_tensor(src, shape=[1, 64], dtype=pto.float32)
+        idx_view = pto.make_tensor(idx, shape=[1, 64], dtype=pto.uint32)
+        dst_view = pto.make_tensor(dst, shape=[1, 128], dtype=pto.float32)
+        with pto.vector_section():
+            a5.sort32_micro(
+                src_view.slice([0, 0], [1, 64]),
+                idx_view.slice([0, 0], [1, 64]),
+                dst_view.slice([0, 0], [1, 128]),
+                dtype=pto.float32,
+                shape=[1, 64],
+            )
+
+    text = str(a5_sort32_micro)
+
+    assert "func.func @a5_sort32_micro" in text
+    assert "pto.vbitsort" in text
+    assert "pto.tsort32" not in text
+
+
+def test_a5_mrgsort_micro_emits_vmrgsort4():
+    """Build a kernel around ``a5.mrgsort_micro`` and inspect its IR text.
+
+    Source and destination are [1, 256] float32 views and the helper is
+    called with ``block_len=64``. The text must contain ``pto.vmrgsort4`` and
+    must not contain ``pto.tmrgsort``.
+    """
+
+    def meta_data():
+        return {"ptr_t": pto.ptr(pto.float32)}
+
+    @to_ir_module(meta_data=meta_data)
+    def a5_mrgsort_micro(src: "ptr_t", dst: "ptr_t") -> None:
+        src_view = pto.make_tensor(src, shape=[1, 256], dtype=pto.float32)
+        dst_view = pto.make_tensor(dst, shape=[1, 256], dtype=pto.float32)
+        with pto.vector_section():
+            a5.mrgsort_micro(
+                src_view.slice([0, 0], [1, 256]),
+                dst_view.slice([0, 0], [1, 256]),
+                dtype=pto.float32,
+                shape=[1, 256],
+                block_len=64,
+            )
+
+    text = str(a5_mrgsort_micro)
+
+    assert "func.func @a5_mrgsort_micro" in text
+    assert "pto.vmrgsort4" in text
+    assert "pto.tmrgsort" not in text
+
+
+def test_a5_generation_script_emits_pto_files(tmp_path):
+    """Run ``emit_kernels`` into ``tmp_path`` and check the files it reports.
+
+    The returned paths must have exactly the three expected ``.pto`` names,
+    and every file's UTF-8 text must contain ``func.func @``.
+    """
+    generated = emit_kernels(output_dir=tmp_path)
+
+    # NOTE(review): generated is iterated twice (here and in the loop below);
+    # if emit_kernels returned a one-shot iterator the loop would check
+    # nothing. Confirm it returns a list or similar.
+    generated_names = sorted(path.name for path in generated)
+    assert generated_names == [
+        "a5_cube_matmul.pto",
+        "a5_elementwise_add.pto",
+        "a5_micro_vector_copy.pto",
+    ]
+
+    for path in generated:
+        text = path.read_text(encoding="utf-8")
+        assert "func.func @" in text
+
+
+def test_a5_add_micro_rejects_view_dtype_mismatch():
+    """Expect ``ValueError`` when ``a5.add_micro`` gets a mismatching dtype.
+
+    All three views are float16 while the helper is called with
+    ``dtype=pto.float32``. The error must be raised while the decorated
+    kernel is being defined, since that is all the ``with`` body contains.
+    """
+
+    def meta_data():
+        return {"ptr_t": pto.ptr(pto.float16)}
+
+    with pytest.raises(
+        ValueError, match="TADD input tile src0, src1 and dst tile data type mismatch"
+    ):
+
+        @to_ir_module(meta_data=meta_data)
+        def invalid_add(src0: "ptr_t", src1: "ptr_t", dst: "ptr_t") -> None:
+            lhs = pto.make_tensor(src0, shape=[32, 32], dtype=pto.float16)
+            rhs = pto.make_tensor(src1, shape=[32, 32], dtype=pto.float16)
+            out = pto.make_tensor(dst, shape=[32, 32], dtype=pto.float16)
+            with pto.vector_section():
+                a5.add_micro(
+                    lhs.slice([0, 0], [32, 32]),
+                    rhs.slice([0, 0], [32, 32]),
+                    out.slice([0, 0], [32, 32]),
+                    dtype=pto.float32,
+                    shape=[32, 32],
+                )
+
+
+def test_a5_row_expand_micro_rejects_non_column_source():
+    """Expect ``ValueError`` when ``a5.row_expand_micro`` gets a [1, 32] source.
+
+    The expected message says the source valid shape must be ``[rows, 1]``.
+    """
+
+    def meta_data():
+        return {"ptr_t": pto.ptr(pto.float32)}
+
+    # match= is a regular expression, hence the escaped square brackets.
+    with pytest.raises(
+        ValueError, match="TROWEXPAND source valid shape must be \\[rows, 1\\]"
+    ):
+
+        @to_ir_module(meta_data=meta_data)
+        def invalid_row_expand(src: "ptr_t", dst: "ptr_t") -> None:
+            src_view = pto.make_tensor(src, shape=[1, 32], dtype=pto.float32)
+            dst_view = pto.make_tensor(dst, shape=[32, 32], dtype=pto.float32)
+            with pto.vector_section():
+                a5.row_expand_micro(
+                    src_view.slice([0, 0], [1, 32]),
+                    dst_view.slice([0, 0], [32, 32]),
+                    dtype=pto.float32,
+                    shape=[32, 32],
+                )
+
+
+def test_a5_row_reduce_micro_rejects_non_single_column_output():
+    """Expect ``ValueError`` when ``a5.row_sum_micro`` gets a [1, 32] output.
+
+    The expected message contains "use a single-column output tile".
+    """
+
+    def meta_data():
+        return {"ptr_t": pto.ptr(pto.float32)}
+
+    with pytest.raises(ValueError, match="use a single-column output tile"):
+
+        @to_ir_module(meta_data=meta_data)
+        def invalid_row_reduce(src: "ptr_t", dst: "ptr_t") -> None:
+            src_view = pto.make_tensor(src, shape=[32, 32], dtype=pto.float32)
+            dst_view = pto.make_tensor(dst, shape=[1, 32], dtype=pto.float32)
+            with pto.vector_section():
+                a5.row_sum_micro(
+                    src_view.slice([0, 0], [32, 32]),
+                    dst_view.slice([0, 0], [1, 32]),
+                    dtype=pto.float32,
+                    shape=[32, 32],
+                )
+
+
+def test_a5_col_reduce_micro_rejects_unsupported_dtype():
+    """Expect ``ValueError`` when ``a5.col_sum_micro`` is used with ``pto.bool``.
+
+    Pointer, views and the helper's ``dtype`` are all ``pto.bool``; the
+    expected message is "TCOLREDUCE input data type is not supported".
+    """
+
+    def meta_data():
+        return {"ptr_t": pto.ptr(pto.bool)}
+
+    with pytest.raises(ValueError, match="TCOLREDUCE input data type is not supported"):
+
+        @to_ir_module(meta_data=meta_data)
+        def invalid_col_reduce(src: "ptr_t", dst: "ptr_t") -> None:
+            src_view = pto.make_tensor(src, shape=[32, 32], dtype=pto.bool)
+            dst_view = pto.make_tensor(dst, shape=[1, 32], dtype=pto.bool)
+            with pto.vector_section():
+                a5.col_sum_micro(
+                    src_view.slice([0, 0], [32, 32]),
+                    dst_view.slice([0, 0], [1, 32]),
+                    dtype=pto.bool,
+                    shape=[32, 32],
+                )
diff --git a/tests/regression/test_tile_micro_coverage.py b/tests/regression/test_tile_micro_coverage.py
new file mode 100644
index 00000000..81dbe74c
--- /dev/null
+++ b/tests/regression/test_tile_micro_coverage.py
@@ -0,0 +1,45 @@
+from pathlib import Path
+
+from ptodsl import tile
+from ptodsl.lib import a5
+from ptodsl.lib.a5.tile_micro_coverage import (
+    TILE_MICRO_COVERAGE,
+    coverage_markdown,
+    coverage_summary,
+)
+
+
+def test_tile_micro_coverage_checklist_covers_every_tile_api_symbol():
+    """Require the coverage table keys to equal the names in ``tile.__all__``."""
+    assert set(TILE_MICRO_COVERAGE) == set(tile.__all__)
+
+
+def test_implemented_tile_micro_helpers_exist():
+    """Require every entry marked "implemented" to name an attribute of ``a5``.
+
+    Failures name the offending tile op so a broken table entry can be found
+    without a debugger.
+    """
+    for name, entry in TILE_MICRO_COVERAGE.items():
+        # Read "helper" before the status check so every entry, implemented
+        # or not, must carry the key.
+        helper = entry["helper"]
+        if entry["status"] != "implemented":
+            continue
+        assert helper is not None, f"{name}: implemented entry has no helper"
+        # getattr with a default turns a missing attribute into an assertion
+        # failure that names the tile op, instead of a bare AttributeError.
+        assert (
+            getattr(a5, helper, None) is not None
+        ), f"{name}: ptodsl.lib.a5 has no helper {helper!r}"
+
+
+def test_tile_micro_coverage_markdown_mentions_all_tile_ops():
+    """Require the generated markdown to mention each tile op in backticks."""
+    text = coverage_markdown()
+    # Collect every missing name so one run reports the full gap.
+    missing = [name for name in tile.__all__ if f"`{name}`" not in text]
+    assert not missing, f"tile ops missing from coverage markdown: {missing}"
+
+
+def test_tile_micro_coverage_summary_matches_public_surface():
+    """Require the summary counts to add up to ``len(tile.__all__)``.
+
+    Both the "implemented" and the "blocked" counts must be positive.
+    """
+    counts = coverage_summary()
+    assert sum(counts.values()) == len(tile.__all__)
+    assert counts["implemented"] > 0
+    assert counts["blocked"] > 0
+
+
+def test_checked_in_tile_micro_checklist_is_in_sync():
+    """Require the checked-in checklist to equal ``coverage_markdown()``.
+
+    The file is located two directories above this test's directory, under
+    ``ptodsl/lib/a5/TILE_MICRO_CHECKLIST.md``.
+    """
+    checklist = (
+        Path(__file__).resolve().parents[2]
+        / "ptodsl"
+        / "lib"
+        / "a5"
+        / "TILE_MICRO_CHECKLIST.md"
+    )
+    assert (
+        checklist.read_text(encoding="utf-8") == coverage_markdown()
+    ), f"{checklist} is out of date with coverage_markdown()"